Upload 72 files
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- bindevaluator.py +182 -0
- classifier_code/__init__.py +0 -0
- classifier_code/binding_affinity_unpooled.py +321 -0
- classifier_code/binding_affinity_unpooled_2.py +356 -0
- classifier_code/half_life.py +65 -0
- classifier_code/hemolysis_wt.py +101 -0
- classifier_code/nonfouling_wt.py +98 -0
- classifier_code/solubility_wt.py +98 -0
- flow_matching/__init__.py +7 -0
- flow_matching/loss/__init__.py +11 -0
- flow_matching/loss/generalized_loss.py +83 -0
- flow_matching/path/__init__.py +22 -0
- flow_matching/path/affine.py +260 -0
- flow_matching/path/geodesic.py +100 -0
- flow_matching/path/mixture.py +117 -0
- flow_matching/path/path.py +61 -0
- flow_matching/path/path_sample.py +53 -0
- flow_matching/path/scheduler/__init__.py +29 -0
- flow_matching/path/scheduler/schedule_transform.py +148 -0
- flow_matching/path/scheduler/scheduler.py +199 -0
- flow_matching/solver/__init__.py +18 -0
- flow_matching/solver/discrete_solver.py +428 -0
- flow_matching/solver/ode_solver.py +197 -0
- flow_matching/solver/riemannian_ode_solver.py +261 -0
- flow_matching/solver/solver.py +17 -0
- flow_matching/solver/utils.py +19 -0
- flow_matching/utils/__init__.py +17 -0
- flow_matching/utils/categorical_sampler.py +23 -0
- flow_matching/utils/manifolds/__init__.py +18 -0
- flow_matching/utils/manifolds/manifold.py +93 -0
- flow_matching/utils/manifolds/sphere.py +45 -0
- flow_matching/utils/manifolds/torus.py +28 -0
- flow_matching/utils/manifolds/utils.py +45 -0
- flow_matching/utils/model_wrapper.py +43 -0
- flow_matching/utils/multi_guidance.py +216 -0
- flow_matching/utils/multi_guidance_cnp.py +217 -0
- flow_matching/utils/utils.py +90 -0
- models/classifier.py +116 -0
- models/enhancer_models.py +215 -0
- models/peptide_classifiers.py +751 -0
- models/peptide_models.py +359 -0
- modules/bindevaluator_modules/__init__.py +3 -0
- modules/bindevaluator_modules/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/bindevaluator_modules/__pycache__/__init__.cpython-38.pyc +0 -0
- modules/bindevaluator_modules/__pycache__/__init__.cpython-39.pyc +0 -0
- modules/bindevaluator_modules/__pycache__/dataloaders.cpython-310.pyc +0 -0
- modules/bindevaluator_modules/__pycache__/dataloaders.cpython-38.pyc +0 -0
- modules/bindevaluator_modules/__pycache__/dataloaders.cpython-39.pyc +0 -0
- modules/bindevaluator_modules/__pycache__/layers.cpython-310.pyc +0 -0
- modules/bindevaluator_modules/__pycache__/layers.cpython-38.pyc +0 -0
bindevaluator.py
ADDED

@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import AutoTokenizer, EsmModel
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, accuracy_score
from argparse import ArgumentParser
import os
import torch.distributed as dist
import pandas as pd
import pdb

from modules.bindevaluator_modules import *  # Import your model and other necessary classes/functions here


def parse_motifs(motif: str) -> torch.Tensor:
    """Parse a motif string like '3-5, 8' into a tensor of 0-indexed positions."""
    parts = motif.split(',')
    result = []

    for part in parts:
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))

    # Shift from 1-indexed user input to 0-indexed positions
    result = [pos - 1 for pos in result]
    print(f'Target Motifs: {result}')
    return torch.tensor(result)


class PeptideModel(pl.LightningModule):
    def __init__(self, n_layers, d_model, d_hidden, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15, kl_weight=1):
        super(PeptideModel, self).__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # Freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.kl_weight = kl_weight

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
        self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site weight, non-binding-site weight

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, prot_seq_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.output_projection_prot(prot_enc)

        return prot_enc


def calculate_score(target_sequence, binder_sequence, model, args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    anchor_tokens = tokenizer(target_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000)
    positive_tokens = tokenizer(binder_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000)

    # Mask out the BOS/EOS special tokens
    anchor_tokens['attention_mask'][0][0] = 0
    anchor_tokens['attention_mask'][0][-1] = 0
    positive_tokens['attention_mask'][0][0] = 0
    positive_tokens['attention_mask'][0][-1] = 0

    target_tokens = {'input_ids': anchor_tokens["input_ids"].to(device),
                     'attention_mask': anchor_tokens["attention_mask"].to(device)}
    binder_tokens = {'input_ids': positive_tokens['input_ids'].to(device),
                     'attention_mask': positive_tokens['attention_mask'].to(device)}

    model.eval()

    # pdb.set_trace()

    prediction = model(binder_tokens, target_tokens).squeeze(-1)[0][1:-1]
    prediction = torch.sigmoid(prediction)

    return prediction, model.classification_threshold


def compute_metrics(true_residues, predicted_residues, length):
    # Initialize the true and predicted lists with 0
    true_list = [0] * length
    predicted_list = [0] * length

    # Set the values to 1 based on the provided lists
    for index in true_residues:
        true_list[index] = 1
    for index in predicted_residues:
        predicted_list[index] = 1

    # Compute the metrics
    accuracy = accuracy_score(true_list, predicted_list)
    f1 = f1_score(true_list, predicted_list)
    mcc = matthews_corrcoef(true_list, predicted_list)

    return accuracy, f1, mcc


def main():
    parser = ArgumentParser()
    parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt',
                        help="File containing initial params", type=str)
    parser.add_argument("-batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
    parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
    parser.add_argument("-d_hidden", type=int, default=128, help="Dimension of CNN block")
    parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
    parser.add_argument("-d_inner", type=int, default=64)
    parser.add_argument("-target", type=str)
    parser.add_argument("-binder", type=str)
    parser.add_argument("-gt", type=str, default=None)
    parser.add_argument("-motifs", type=str, default=None)
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = PeptideModel.load_from_checkpoint(args.sm,
                                              n_layers=args.n_layers,
                                              d_model=args.d_model,
                                              d_hidden=args.d_hidden,
                                              n_head=args.n_head,
                                              d_k=64,
                                              d_v=128,
                                              d_inner=64).to(device)

    prediction, _ = calculate_score(args.target, args.binder, model, args)
    # print(prediction)
    # print(model.classification_threshold)

    binding_site = []
    for i in range(len(prediction)):
        if prediction[i] >= 0.5:
            binding_site.append(i)

    print("Prediction: ", binding_site)
    prediction = prediction.detach().cpu()  # keep a tensor so the motif indexing below works
    np.set_printoptions(precision=2, suppress=True)
    print(prediction.numpy())

    if args.motifs is not None:
        motifs = parse_motifs(args.motifs).tolist()
        print(f"Motif Score: {torch.sum(prediction[motifs]) / len(motifs)}")

    if args.gt is not None:
        L = len(args.target)
        # print(L)
        gt = parse_motifs(args.gt)
        print("Ground Truth: ", gt)

        acc, f1, mcc = compute_metrics(gt, binding_site, L)
        print(f"Accuracy={acc}\tF1={f1}\tMCC={mcc}")

    # print("Prediction Logits: ", prediction[binding_site])


if __name__ == "__main__":
    main()
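For reference, parse_motifs treats its input as 1-indexed, expands inclusive ranges, and returns 0-indexed tensor indices. A quick sanity check, assuming bindevaluator.py is importable from the repository root:

from bindevaluator import parse_motifs

# "3-5" expands to positions 3, 4, 5 (inclusive); "8" stays a single
# position; everything is then shifted down by one to 0-indexed.
print(parse_motifs("3-5, 8").tolist())  # -> [2, 3, 4, 7]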
classifier_code/__init__.py
ADDED

File without changes
classifier_code/binding_affinity_unpooled.py
ADDED

@@ -0,0 +1,321 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from scipy.stats import spearmanr
from collections import defaultdict
import pandas as pd
import logging
import os
import torch.optim as optim
from datetime import datetime
from transformers import AutoModel, AutoConfig, AutoTokenizer


class UnpooledBindingPredictor(nn.Module):
    def __init__(self,
                 esm_model_name="facebook/esm2_t33_650M_UR50D",
                 hidden_dim=512,
                 kernel_sizes=[3, 5, 7],
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1,
                 freeze_esm=True):
        super().__init__()

        # Define binding thresholds
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM

        # Load ESM model for computing embeddings on the fly
        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.config = AutoConfig.from_pretrained(esm_model_name)

        # Freeze ESM parameters if needed
        if freeze_esm:
            for param in self.esm_model.parameters():
                param.requires_grad = False

        # Get ESM hidden size
        esm_dim = self.config.hidden_size

        # Output channels for CNN layers
        output_channels_per_kernel = 64

        # CNN layers for handling variable length sequences
        self.protein_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        self.binder_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        # Calculate total features after convolution and pooling
        total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2

        # Project to same dimension after CNN processing
        self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
        self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)

        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.binder_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, weak binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly"""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Get the unpooled last hidden states (batch_size x seq_length x hidden_size)
        return esm_outputs.last_hidden_state

    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
        """Process a sequence through CNN layers and pooling"""
        # Transpose for CNN: [batch_size, hidden_size, seq_length]
        x = unpooled_emb.transpose(1, 2)

        # Apply CNN layers and collect outputs
        conv_outputs = []
        for conv in conv_layers:
            conv_out = F.relu(conv(x))
            conv_outputs.append(conv_out)

        # Concatenate along channel dimension
        conv_output = torch.cat(conv_outputs, dim=1)

        # Global pooling (both max and average)
        # If attention mask is provided, use it to create a proper mask for pooling
        if attention_mask is not None:
            # Create a mask for pooling (1 for valid positions, 0 for padding);
            # expand to match conv_output channels and cast to float so the
            # sums/clamp below stay in floating point
            expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1).float()

            # Apply mask (set padding to large negative value for max pooling)
            masked_output = conv_output.clone()
            masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))

            # Max pooling along sequence dimension
            max_pooled = torch.max(masked_output, dim=2)[0]

            # Average pooling (sum divided by number of valid positions)
            sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
            valid_positions = torch.sum(expanded_mask, dim=2)
            valid_positions = torch.clamp(valid_positions, min=1.0)  # Avoid division by zero
            avg_pooled = sum_pooled / valid_positions
        else:
            # If no mask, use standard pooling
            max_pooled = torch.max(conv_output, dim=2)[0]
            avg_pooled = torch.mean(conv_output, dim=2)

        # Concatenate the pooled features
        pooled = torch.cat([max_pooled, avg_pooled], dim=1)

        return pooled

    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
        # Compute embeddings on the fly using the ESM model
        protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
        binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)

        # Process protein and binder sequences through CNN layers
        protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
        binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)

        # Project to same dimension
        protein = self.protein_norm(self.protein_projection(protein_features))
        binder = self.binder_norm(self.binder_projection(binder_features))

        # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim]
        protein = protein.unsqueeze(0)
        binder = binder.unsqueeze(0)

        # Cross attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to binder
            attended_protein = layer['attention'](
                protein, binder, binder
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # Binder attending to protein
            attended_binder = layer['attention'](
                binder, protein, protein
            )[0]
            binder = layer['norm1'](binder + attended_binder)
            binder = layer['norm2'](binder + layer['ffn'](binder))

        # Remove sequence dimension
        protein_pool = protein.squeeze(0)
        binder_pool = binder.squeeze(0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, binder_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits


def load_model(checkpoint_path, device):
    """Load trained model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Import the model class from your module or redefine it here

    # Initialize model with the same parameters used during training
    model = UnpooledBindingPredictor(
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        hidden_dim=384,
        kernel_sizes=[3, 5, 7],
        n_heads=8,
        n_layers=4,
        dropout=0.14561457009902096,
        freeze_esm=True
    ).to(device)

    # Load the trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set to evaluation mode

    return model


def prepare_inputs(protein_sequence, binder_sequence, tokenizer, max_length=1024, device='cuda'):
    """Tokenize protein and binder sequences."""
    protein_tokens = tokenizer(
        protein_sequence,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    )

    binder_tokens = tokenizer(
        binder_sequence,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    )

    return {
        'protein_input_ids': protein_tokens['input_ids'].to(device),
        'protein_attention_mask': protein_tokens['attention_mask'].to(device),
        'binder_input_ids': binder_tokens['input_ids'].to(device),
        'binder_attention_mask': binder_tokens['attention_mask'].to(device)
    }


# Perform prediction
def predict_binding(model, protein_sequence, binder_sequence, device='cuda'):
    """Predict binding affinity between protein and binder sequences."""
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    inputs = prepare_inputs(protein_sequence, binder_sequence, tokenizer, device=device)

    with torch.no_grad():
        regression_output, classification_logits = model(
            inputs['protein_input_ids'],
            inputs['binder_input_ids'],
            inputs['protein_attention_mask'],
            inputs['binder_attention_mask']
        )

    # Get numerical prediction (pKd/pKi)
    predicted_affinity = regression_output.item()

    # Get classification prediction (tight, medium, weak)
    predicted_class_idx = torch.argmax(classification_logits, dim=1).item()
    class_names = ['Tight binding', 'Medium binding', 'Weak binding']
    predicted_class = class_names[predicted_class_idx]

    # Get class probabilities
    class_probs = F.softmax(classification_logits, dim=1).cpu().numpy()[0]

    return {
        'predicted_affinity': predicted_affinity,
        'binding_class': predicted_class,
        'class_probabilities': {name: prob for name, prob in zip(class_names, class_probs)},
        'tight_threshold': model.tight_threshold,  # 7.5 (≤ ~30nM)
        'weak_threshold': model.weak_threshold     # 6.0 (> 1μM)
    }


# Example usage
if __name__ == "__main__":
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model
    model = load_model('../classifier_ckpt/binding_affinity_unpooled.pt', device)

    protein_sequence = "GSHMIEPNVISVRLFKRKVGGLGFLVKERVSKPPVIISDLIRGGAAEQSGLIQAGDIILAVNDRPLVDLSYDSALEVLRGIASETHVVLILRGPEGFTTHLETTFTGDGTPKTIRVTQPLGPPTKAV"
    binder_sequence = "VVKVDSV"

    result = predict_binding(model, protein_sequence, binder_sequence, device)
    print(f"Affinity Score: {result['predicted_affinity']}")
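The thresholds used by this classifier are on the pKd/pKi scale, where pKd = -log10(Kd in molar), so larger values mean tighter binding. A quick conversion confirms the inline comments next to the thresholds:

# pKd = -log10(Kd [M]); invert with Kd = 10**(-pKd).
for pkd in (7.5, 6.0):
    kd_nM = 10 ** (-pkd) * 1e9
    print(f"pKd {pkd} -> Kd = {kd_nM:.1f} nM")
# pKd 7.5 -> Kd = 31.6 nM    (tight threshold, <= ~30 nM)
# pKd 6.0 -> Kd = 1000.0 nM  (weak threshold, 1 uM)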
classifier_code/binding_affinity_unpooled_2.py
ADDED

@@ -0,0 +1,356 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from scipy.stats import spearmanr
from collections import defaultdict
import pandas as pd
import logging
import os
import torch.optim as optim
from datetime import datetime
from transformers import AutoModel, AutoConfig, AutoTokenizer

# point HF_ENDPOINT at your mirror
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"


class UnpooledBindingPredictor(nn.Module):
    def __init__(self,
                 esm_model_name="facebook/esm2_t33_650M_UR50D",
                 hidden_dim=512,
                 kernel_sizes=[3, 5, 7],
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1,
                 freeze_esm=True):
        super().__init__()

        # Define binding thresholds
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM

        # Load ESM model for computing embeddings on the fly
        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.config = AutoConfig.from_pretrained(esm_model_name)

        # Freeze ESM parameters if needed
        if freeze_esm:
            for param in self.esm_model.parameters():
                param.requires_grad = False

        # Get ESM hidden size
        esm_dim = self.config.hidden_size

        # Output channels for CNN layers
        output_channels_per_kernel = 64

        # CNN layers for handling variable length sequences
        self.protein_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        self.binder_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        # Calculate total features after convolution and pooling
        total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2

        # Project to same dimension after CNN processing
        self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
        self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)

        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.binder_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, weak binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly"""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Get the unpooled last hidden states (batch_size x seq_length x hidden_size)
        return esm_outputs.last_hidden_state

    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
        """Process a sequence through CNN layers and pooling"""
        # Transpose for CNN: [batch_size, hidden_size, seq_length]
        x = unpooled_emb.transpose(1, 2)

        # Apply CNN layers and collect outputs
        conv_outputs = []
        for conv in conv_layers:
            conv_out = F.relu(conv(x))
            conv_outputs.append(conv_out)

        # Concatenate along channel dimension
        conv_output = torch.cat(conv_outputs, dim=1)

        # Global pooling (both max and average)
        # If attention mask is provided, use it to create a proper mask for pooling
        if attention_mask is not None:
            # Create a mask for pooling (1 for valid positions, 0 for padding);
            # expand to match conv_output channels and cast to float so the
            # sums/clamp below stay in floating point
            expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1).float()

            # Apply mask (set padding to large negative value for max pooling)
            masked_output = conv_output.clone()
            masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))

            # Max pooling along sequence dimension
            max_pooled = torch.max(masked_output, dim=2)[0]

            # Average pooling (sum divided by number of valid positions)
            sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
            valid_positions = torch.sum(expanded_mask, dim=2)
            valid_positions = torch.clamp(valid_positions, min=1.0)  # Avoid division by zero
            avg_pooled = sum_pooled / valid_positions
        else:
            # If no mask, use standard pooling
            max_pooled = torch.max(conv_output, dim=2)[0]
            avg_pooled = torch.mean(conv_output, dim=2)

        # Concatenate the pooled features
        pooled = torch.cat([max_pooled, avg_pooled], dim=1)

        return pooled

    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
        # Compute embeddings on the fly using the ESM model
        protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
        binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)

        # Process protein and binder sequences through CNN layers
        protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
        binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)

        # Project to same dimension
        protein = self.protein_norm(self.protein_projection(protein_features))
        binder = self.binder_norm(self.binder_projection(binder_features))

        # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim]
        protein = protein.unsqueeze(0)
        binder = binder.unsqueeze(0)

        # Cross attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to binder
            attended_protein = layer['attention'](
                protein, binder, binder
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # Binder attending to protein
            attended_binder = layer['attention'](
                binder, protein, protein
            )[0]
            binder = layer['norm1'](binder + attended_binder)
            binder = layer['norm2'](binder + layer['ffn'](binder))

        # Remove sequence dimension
        protein_pool = protein.squeeze(0)
        binder_pool = binder.squeeze(0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, binder_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits


def load_model(checkpoint_path, device):
    """Load trained model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Import the model class from your module or redefine it here

    # Initialize model with the same parameters used during training
    model = UnpooledBindingPredictor(
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        hidden_dim=384,
        kernel_sizes=[3, 5, 7],
        n_heads=8,
        n_layers=4,
        dropout=0.14561457009902096,
        freeze_esm=True
    ).to(device)

    # Load the trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set to evaluation mode

    return model


def prepare_inputs(protein_sequence, binder_sequence, tokenizer, max_length=1024, device='cuda'):
    """Tokenize protein and binder sequences."""
    protein_tokens = tokenizer(
        protein_sequence,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    )

    binder_tokens = tokenizer(
        binder_sequence,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    )

    return {
        'protein_input_ids': protein_tokens['input_ids'].to(device),
        'protein_attention_mask': protein_tokens['attention_mask'].to(device),
        'binder_input_ids': binder_tokens['input_ids'].to(device),
        'binder_attention_mask': binder_tokens['attention_mask'].to(device)
    }


# Perform prediction
def predict_binding(model, protein_sequence, binder_sequence, device='cuda'):
    """Predict binding affinity between protein and binder sequences."""
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    inputs = prepare_inputs(protein_sequence, binder_sequence, tokenizer, device=device)

    with torch.no_grad():
        regression_output, classification_logits = model(
            inputs['protein_input_ids'],
            inputs['binder_input_ids'],
            inputs['protein_attention_mask'],
            inputs['binder_attention_mask']
        )

    # Get numerical prediction (pKd/pKi)
    predicted_affinity = regression_output.item()

    # Get classification prediction (tight, medium, weak)
    predicted_class_idx = torch.argmax(classification_logits, dim=1).item()
    class_names = ['Tight binding', 'Medium binding', 'Weak binding']
    predicted_class = class_names[predicted_class_idx]

    # Get class probabilities
    class_probs = F.softmax(classification_logits, dim=1).cpu().numpy()[0]

    return {
        'predicted_affinity': predicted_affinity,
        'binding_class': predicted_class,
        'class_probabilities': {name: prob for name, prob in zip(class_names, class_probs)},
        'tight_threshold': model.tight_threshold,  # 7.5 (≤ ~30nM)
        'weak_threshold': model.weak_threshold     # 6.0 (> 1μM)
    }


# Example usage
if __name__ == "__main__":
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model
    model = load_model('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/binding_affinity_unpooled.pt', device)

    # Example protein sequences (replace with actual sequences)
    binders = ['GLSKGCFGLKLDRIGSMSGLGC', 'RGLSDGFLKLKMGISGSLGC']
    protein_sequence = "RNLTLAVVLPEHNLSYAWAWPRVGPAVALAVEALGRALPVDLRFVSSELEGACSEYLAPLSAVDLKLYHDPDLLLGPGCVYPAASVARFASHWRLPLLTAGAVASGFSAKNDHYRTLVRTGPSAPKLGEFVVTLHGHFNWTARAALLYLDARTDDRPHYFTIEGVFEALQGSNLSVQHQVYAREPGGPEQATHFIRANGRIVYICGPLEMLHEILLQAQRENLTNGDYVFFYLDVFGESLRAGPTRATGRPWQDNRTREQAQALREAFQTVLVITYREPPNPEYQEFQNRLLIRAREDFGVELGPSLMNLIAGCFYDGILLYAEVLNETIQEGGTREDGLRIVEKMQGRRYHGVTGLVVMDKNNDRETDFVLWAMGDLDSGDFQPAAHYSGAEKQIWWTGRPIPWVKGAPPSDNPPCAFDLDDPSCDKTPLSTLAI"

    # name = "CLIC1_10_moppit"
    # print(name)
    # with open(f'/home/tc415/flow_matching/samples/unconditional_samples/12.txt', 'r') as f:
    #     binders = f.readlines()
    #     binders = [binder.strip() for binder in binders]
    #     binders = binders[:100]

    # Make prediction
    affinities = []
    for binder in binders:
        result = predict_binding(model, protein_sequence, binder, device)
        print(result['predicted_affinity'])
        affinities.append(result['predicted_affinity'])

    # with open('/home/tc415/flow_matching/scores/affinity/EWSFLI1_12_unconditional.txt', 'w') as f:
    #     for score in affinities:
    #         f.write(str(score) + '\n')

    # print(sum(affinities) / len(affinities))

    # with open(f'/home/tc415/flow_matching/scores/affinity/{name}.txt', 'w') as f:
    #     for score in affinities:
    #         f.write(str(round(score, 4)) + '\n')

    # Display results
    # print(f"Predicted binding affinity (pKd/pKi): {result['predicted_affinity']:.2f}")
    # print(f"Binding class: {result['binding_class']}")
    # print("Class probabilities:")
    # for class_name, prob in result['class_probabilities'].items():
    #     print(f"  {class_name}: {prob:.2f}")
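process_sequence pools each convolutional channel twice: a masked max (padding filled with -inf so it can never be selected) and a masked mean (sum over valid positions divided by their count). A minimal standalone sketch of that masking logic with toy shapes and hypothetical values:

import torch

conv_output = torch.randn(2, 4, 6)  # (batch, channels, seq_len)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1).float()

# Masked max: padded positions can never win.
max_pooled = conv_output.masked_fill(mask == 0, float('-inf')).max(dim=2)[0]

# Masked mean: sum over valid positions / number of valid positions.
avg_pooled = (conv_output * mask).sum(dim=2) / mask.sum(dim=2).clamp(min=1.0)

print(max_pooled.shape, avg_pooled.shape)  # torch.Size([2, 4]) twice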
classifier_code/half_life.py
ADDED

@@ -0,0 +1,65 @@
import numpy as np
import torch
import xgboost as xgb
from transformers import EsmModel, EsmTokenizer
import torch.nn as nn
import pdb

class PeptideCNN(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
        self.fc = nn.Linear(hidden_dims[1], output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)  # For regression/classification

        # Relies on the module-level `device`, which is defined below but exists
        # by the time the model is instantiated at the bottom of this script
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        with torch.no_grad():
            x = self.esm_model(input_ids, attention_mask).last_hidden_state
        # pdb.set_trace()
        # x shape: (B, L, input_dim)
        x = x.permute(0, 2, 1)  # Reshape to (B, input_dim, L) for Conv1d
        x = nn.functional.relu(self.conv1(x))
        x = self.dropout(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.permute(0, 2, 1)  # Reshape back to (B, L, hidden_dims[1])

        # Global average pooling over the sequence dimension (L)
        x = x.mean(dim=1)  # Shape: (B, hidden_dims[1])

        features = self.fc(x)  # features shape: (B, output_dim)
        if return_features:
            return features
        return self.predictor(features)  # Output shape: (B, 1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_dim = 1280
hidden_dims = [input_dim // 2, input_dim // 4]
output_dim = input_dim // 8
dropout_rate = 0.3

nn_model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
nn_model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth'))
nn_model.eval()

def predict(inputs):
    with torch.no_grad():
        prediction = nn_model(**inputs, return_features=False)

    return prediction.item()

if __name__ == '__main__':
    sequence = 'RGLSDGFLKLKMGISGSLGC'

    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)

    prediction = predict(inputs)
    print(prediction)
    print(f"Predicted half life of {sequence} is {(10**prediction):.4f} h")
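As the final print suggests, this regressor works on a log10 scale: the raw network output is log10 of the half-life in hours and is exponentiated for display. A small sanity check of that convention:

# The model predicts log10(half-life in hours); invert with 10**x.
for log10_hl in (-1.0, 0.0, 1.3):
    print(f"log10 prediction {log10_hl:+.1f} -> {10 ** log10_hl:.2f} h")
# -1.0 -> 0.10 h, 0.0 -> 1.00 h, +1.3 -> 19.95 h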
classifier_code/hemolysis_wt.py
ADDED

@@ -0,0 +1,101 @@
import sys
import os
sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
import xgboost as xgb
import torch
import numpy as np
import warnings
from rdkit import Chem, rdBase, DataStructs
from transformers import AutoTokenizer, EsmModel

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


class Hemolysis:
    def __init__(self):
        # change model path
        self.predictor = xgb.Booster(model_file='/home/tc415/flow_matching/classifier_ckpt/best_model_hemolysis.json')

        # Load ESM model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences"""
        embeddings = []

        # Process sequences in batches to avoid memory issues
        batch_size = 8
        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i + batch_size]

            inputs = self.tokenizer(
                batch_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}
                self.model = self.model.cuda()

            # Generate embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Get last hidden states
            last_hidden_states = outputs.last_hidden_state
            # pdb.set_trace()
            # Compute mean pooling (excluding padding tokens)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            masked_hidden_states = last_hidden_states * attention_mask
            sum_hidden_states = masked_hidden_states.sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1)
            batch_embeddings = sum_hidden_states / seq_lengths

            batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings.append(batch_embeddings)

        if embeddings:
            return np.vstack(embeddings)
        else:
            return np.array([])

    def get_scores(self, input_seqs: list):
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # Return the probability of being non-hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    hemolysis = Hemolysis()
    sequences = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
        "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD"
    ]

    scores = hemolysis(input_seqs=sequences)
    print([1 - score for score in scores])

if __name__ == '__main__':
    unittest()
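Note the score convention here: the XGBoost booster outputs P(hemolytic), get_scores returns 1 - P(hemolytic) (higher means safer), and unittest() inverts it once more so it prints P(hemolytic) again. A toy illustration of the arithmetic with hypothetical probabilities:

import numpy as np

probs_hemolytic = np.array([0.9, 0.1])  # hypothetical booster outputs
scores = np.ones_like(probs_hemolytic) - probs_hemolytic
print(scores)                    # [0.1 0.9] -> probability of being non-hemolytic
print([1 - s for s in scores])   # [0.9, 0.1] -> what unittest() prints back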
classifier_code/nonfouling_wt.py
ADDED

@@ -0,0 +1,98 @@
import sys
import os
import xgboost as xgb
import torch
import numpy as np
import warnings
from rdkit import rdBase
from transformers import AutoTokenizer, EsmModel

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class Nonfouling:
    def __init__(self):
        # change model path
        self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_nonfouling.json')

        # Load ESM model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences."""
        embeddings = []

        # Process sequences in batches to avoid memory issues
        batch_size = 8
        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i + batch_size]

            inputs = self.tokenizer(
                batch_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}
                self.model = self.model.cuda()

            # Generate embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Get last hidden states
            last_hidden_states = outputs.last_hidden_state

            # Compute mean pooling (excluding padding tokens)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            masked_hidden_states = last_hidden_states * attention_mask
            sum_hidden_states = masked_hidden_states.sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1)
            batch_embeddings = sum_hidden_states / seq_lengths

            batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings.append(batch_embeddings)

        if embeddings:
            return np.vstack(embeddings)
        else:
            return np.array([])

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize features before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    nonfouling = Nonfouling()
    sequences = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
        "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD"
    ]

    scores = nonfouling(input_seqs=sequences)
    print(scores)

if __name__ == '__main__':
    unittest()
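The masked mean pooling inside `generate_embeddings` is the one step that is easy to get wrong: padded positions must be zeroed out and excluded from the denominator. A tiny self-contained check of that arithmetic on toy tensors (illustrative only; no ESM model involved):

import torch

# Toy batch: 2 sequences, max length 3, hidden size 2; the second sequence has one padded position.
last_hidden_states = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]],
                                   [[4., 4.], [6., 6.], [9., 9.]]])
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]]).unsqueeze(-1)   # shape (batch, seq_len, 1)

masked = last_hidden_states * attention_mask               # zero out padded positions
mean = masked.sum(dim=1) / attention_mask.sum(dim=1)       # divide by true lengths, not seq_len

print(mean)  # tensor([[2., 2.], [5., 5.]]) -- the padded [9., 9.] row never enters the average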
classifier_code/solubility_wt.py
ADDED
@@ -0,0 +1,98 @@
import sys
import os
import xgboost as xgb
import torch
import numpy as np
import warnings
from rdkit import rdBase
from transformers import AutoTokenizer, EsmModel

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class Solubility:
    def __init__(self):
        # change model path
        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')

        # Load ESM model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences."""
        embeddings = []

        # Process sequences in batches to avoid memory issues
        batch_size = 8
        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i + batch_size]

            inputs = self.tokenizer(
                batch_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}
                self.model = self.model.cuda()

            # Generate embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Get last hidden states
            last_hidden_states = outputs.last_hidden_state

            # Compute mean pooling (excluding padding tokens)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            masked_hidden_states = last_hidden_states * attention_mask
            sum_hidden_states = masked_hidden_states.sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1)
            batch_embeddings = sum_hidden_states / seq_lengths

            batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings.append(batch_embeddings)

        if embeddings:
            return np.vstack(embeddings)
        else:
            return np.array([])

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize features before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    solubility = Solubility()
    sequences = [
        "GLSKGCFGLKLDRIGSMSGLGC",
        "RGLSDGFLKLKMGISGSLGC"
    ]

    scores = solubility(input_seqs=sequences)
    print(scores)

if __name__ == '__main__':
    unittest()
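`Solubility` is the same pipeline as `Nonfouling` above, but note the checkpoint locations: this file hard-codes an absolute `/scratch/...` path while the nonfouling file uses a relative `../classifier_ckpt/...`, with only a `# change model path` comment as guidance. A hedged sketch of one way to centralize this; the `MOGDFM_CKPT_DIR` environment variable is an illustrative assumption, not something the repo defines:

import os
import xgboost as xgb

def load_classifier(name: str) -> xgb.Booster:
    # Resolve the checkpoint directory from an (assumed) env var, falling back to the relative layout.
    ckpt_dir = os.environ.get("MOGDFM_CKPT_DIR", "../classifier_ckpt")
    return xgb.Booster(model_file=os.path.join(ckpt_dir, f"best_model_{name}.json"))

# e.g. self.predictor = load_classifier("solubility")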
flow_matching/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

__version__ = "1.0.10"
flow_matching/loss/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .generalized_loss import MixturePathGeneralizedKL

__all__ = [
    "MixturePathGeneralizedKL",
]
flow_matching/loss/generalized_loss.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss

from flow_matching.path import MixtureDiscreteProbPath


class MixturePathGeneralizedKL(_Loss):
    r"""A generalized KL loss for discrete flow matching.
    A class that measures the generalized KL of a discrete flow model :math:`p_{1|t}` w.r.t. a probability path given by ``path``. Note: this class assumes that the model is trained on the same path.

    For a model trained on a space :math:`\mathcal{S} = \mathcal{T}^d`, :math:`\mathcal{T} = [K] = \{1,2,\ldots,K\}`, the loss is given by

    .. math::
        \ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggl[ p_{1|t}(x_t^i|x_t) - \delta_{x^i_1}(x_t^i) + (1-\delta_{x^i_1}(x_t^i))\log p_{1|t}(x_1^i|x_t) \biggr],

    where :math:`\kappa_t` is the scheduler associated with ``path``.

    Args:
        path (MixtureDiscreteProbPath): Probability path (x-prediction training).
        reduction (str, optional): Specify the reduction to apply to the output ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction is applied to the output, ``'mean'``: the output is reduced by mean over sequence elements, ``'sum'``: the output is reduced by sum over sequence elements. Defaults to 'mean'.
    """

    def __init__(self, path: MixtureDiscreteProbPath, reduction: str = "mean") -> None:
        super().__init__(None, None, reduction)
        self.path = path

    def forward(self, logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Evaluates the generalized KL loss.

        Args:
            logits (Tensor): posterior model output (i.e., softmax(``logits``) :math:`=p_{1|t}(x|x_t)`), shape (batch, d, K).
            x_1 (Tensor): target data point :math:`x_1 \sim q`, shape (batch, d).
            x_t (Tensor): conditional sample :math:`x_t \sim p_t(\cdot|x_1)`, shape (batch, d).
            t (Tensor): times in :math:`[0,1]`, shape (batch).

        Raises:
            ValueError: reduction value must be one of ``'none'`` | ``'mean'`` | ``'sum'``.

        Returns:
            Tensor: Generalized KL loss.
        """
        x_1_shape = x_1.shape

        # extract x_1 value of log(p_{1|t}(x|x_t)).
        log_p_1t = torch.log_softmax(logits, dim=-1)
        log_p_1t_x1 = torch.gather(log_p_1t, dim=-1, index=x_1.unsqueeze(-1))
        log_p_1t_x1 = log_p_1t_x1.view(*x_1_shape)

        # extract x_t value of p_{1|t}(x|x_t).
        p_1t = torch.exp(log_p_1t)
        p_1t_xt = torch.gather(p_1t, dim=-1, index=x_t.unsqueeze(-1))
        p_1t_xt = p_1t_xt.view(*x_1_shape)

        scheduler_output = self.path.scheduler(t)

        jump_coefficient = (
            scheduler_output.d_alpha_t / (1 - scheduler_output.alpha_t)
        )[(...,) + (None,) * (x_1.dim() - 1)]
        jump_coefficient = jump_coefficient.repeat(1, *x_1_shape[1:])
        delta_x1_xt = (x_t == x_1).to(log_p_1t.dtype)

        loss = -jump_coefficient * (
            p_1t_xt - delta_x1_xt + (1 - delta_x1_xt) * log_p_1t_x1
        )

        mask = (x_1 != 1).to(loss.dtype)  # 1 is the masked token
        loss = loss * mask

        if self.reduction == "mean":
            return torch.mean(loss)
        elif self.reduction == "sum":
            return torch.sum(loss)
        elif self.reduction == "none":
            return loss
        else:
            raise ValueError(f"{self.reduction} is not a valid value for reduction")
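For orientation, a minimal training-loop sketch wiring `MixturePathGeneralizedKL` to a `MixtureDiscreteProbPath`. The two-layer model, vocabulary size, and all-masked source are illustrative assumptions (a real x-prediction model would also condition on t), not part of the upload:

import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.loss import MixturePathGeneralizedKL

K, d, batch = 8, 16, 4  # assumed vocabulary size, sequence length, batch size
path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
loss_fn = MixturePathGeneralizedKL(path=path)

# Hypothetical x-prediction model: embeds tokens and emits per-position logits over K.
model = torch.nn.Sequential(torch.nn.Embedding(K, 32), torch.nn.Linear(32, K))

x_1 = torch.randint(0, K, (batch, d))   # target tokens
x_0 = torch.full_like(x_1, 1)           # source: all masked (token id 1, matching the loss mask)
t = torch.rand(batch)

sample = path.sample(x_0=x_0, x_1=x_1, t=t)  # DiscretePathSample carrying x_t
logits = model(sample.x_t)                   # shape (batch, d, K)
loss = loss_fn(logits=logits, x_1=x_1, x_t=sample.x_t, t=t)
loss.backward()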
flow_matching/path/__init__.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .affine import AffineProbPath, CondOTProbPath
from .geodesic import GeodesicProbPath
from .mixture import MixtureDiscreteProbPath
from .path import ProbPath
from .path_sample import DiscretePathSample, PathSample


__all__ = [
    "ProbPath",
    "AffineProbPath",
    "CondOTProbPath",
    "MixtureDiscreteProbPath",
    "GeodesicProbPath",
    "PathSample",
    "DiscretePathSample",
]
flow_matching/path/affine.py
ADDED
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from torch import Tensor

from flow_matching.path.path import ProbPath
from flow_matching.path.path_sample import PathSample
from flow_matching.path.scheduler.scheduler import CondOTScheduler, Scheduler
from flow_matching.utils import expand_tensor_like


class AffineProbPath(ProbPath):
    r"""The ``AffineProbPath`` class represents a specific type of probability path where the transformation between distributions is affine.
    An affine transformation can be represented as:

    .. math::

        X_t = \alpha_t X_1 + \sigma_t X_0,

    where :math:`X_t` is the transformed data point at time `t`. :math:`X_0` and :math:`X_1` are the source and target data points, respectively. :math:`\alpha_t` and :math:`\sigma_t` are the parameters of the affine transformation at time `t`.

    The scheduler is responsible for providing the time-dependent parameters :math:`\alpha_t` and :math:`\sigma_t`, as well as their derivatives, which define the affine transformation at any given time `t`.

    Using ``AffineProbPath`` in the flow matching framework:

    .. code-block:: python

        # Instantiates a probability path
        my_path = AffineProbPath(...)
        mse_loss = torch.nn.MSELoss()

        for x_1 in dataset:
            # Sets x_0 to random noise
            x_0 = torch.randn_like(x_1)

            # Sets t to a random value in [0,1]
            t = torch.rand(x_1.shape[0])

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Computes the MSE loss w.r.t. the velocity
            loss = mse_loss(path_sample.dx_t, my_model(path_sample.x_t, t))
            loss.backward()

    Args:
        scheduler (Scheduler): An instance of a scheduler that provides the parameters :math:`\alpha_t`, :math:`\sigma_t`, and their derivatives over time.
    """

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from the affine probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.
        | return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t`, :math:`\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        scheduler_output = self.scheduler(t)

        alpha_t = expand_tensor_like(
            input_tensor=scheduler_output.alpha_t, expand_to=x_1
        )
        sigma_t = expand_tensor_like(
            input_tensor=scheduler_output.sigma_t, expand_to=x_1
        )
        d_alpha_t = expand_tensor_like(
            input_tensor=scheduler_output.d_alpha_t, expand_to=x_1
        )
        d_sigma_t = expand_tensor_like(
            input_tensor=scheduler_output.d_sigma_t, expand_to=x_1
        )

        # construct xt ~ p_t(x|x1).
        x_t = sigma_t * x_0 + alpha_t * x_1
        dx_t = d_sigma_t * x_0 + d_alpha_t * x_1

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)

    def target_to_velocity(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from x_1 representation to velocity.

        | given :math:`X_1`.
        | return :math:`\dot{X}_t`.

        Args:
            x_1 (Tensor): target data point.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = d_sigma_t / sigma_t
        b_t = (d_alpha_t * sigma_t - d_sigma_t * alpha_t) / sigma_t

        return a_t * x_t + b_t * x_1

    def epsilon_to_velocity(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from epsilon representation to velocity.

        | given :math:`\epsilon`.
        | return :math:`\dot{X}_t`.

        Args:
            epsilon (Tensor): noise in the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = d_alpha_t / alpha_t
        b_t = (d_sigma_t * alpha_t - d_alpha_t * sigma_t) / alpha_t

        return a_t * x_t + b_t * epsilon

    def velocity_to_target(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from velocity to x_1 representation.

        | given :math:`\dot{X}_t`.
        | return :math:`X_1`.

        Args:
            velocity (Tensor): velocity at the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: target data point.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = -d_sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t)
        b_t = sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t)

        return a_t * x_t + b_t * velocity

    def epsilon_to_target(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from epsilon representation to x_1 representation.

        | given :math:`\epsilon`.
        | return :math:`X_1`.

        Args:
            epsilon (Tensor): noise in the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: target data point.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        sigma_t = scheduler_output.sigma_t

        a_t = 1 / alpha_t
        b_t = -sigma_t / alpha_t

        return a_t * x_t + b_t * epsilon

    def velocity_to_epsilon(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from velocity to noise representation.

        | given :math:`\dot{X}_t`.
        | return :math:`\epsilon`.

        Args:
            velocity (Tensor): velocity at the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: noise in the path sample.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = -d_alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t)
        b_t = alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t)

        return a_t * x_t + b_t * velocity

    def target_to_epsilon(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from x_1 representation to noise.

        | given :math:`X_1`.
        | return :math:`\epsilon`.

        Args:
            x_1 (Tensor): target data point.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: noise in the path sample.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        sigma_t = scheduler_output.sigma_t

        a_t = 1 / sigma_t
        b_t = -alpha_t / sigma_t

        return a_t * x_t + b_t * x_1


class CondOTProbPath(AffineProbPath):
    r"""The ``CondOTProbPath`` class represents a conditional optimal transport probability path.

    This class is a specialized version of the ``AffineProbPath`` that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation.

    The parameters :math:`\alpha_t` and :math:`\sigma_t` for the conditional optimal transport path are defined as:

    .. math::

        \alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.
    """

    def __init__(self):
        self.scheduler = CondOTScheduler()
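The six conversion helpers above are affine reparameterizations of the same path, so composing a conversion with its inverse should be the identity. A small numerical check (illustrative; per-sample scalar data is used because the helpers apply scheduler coefficients without broadcasting to extra data dimensions):

import torch
from flow_matching.path import CondOTProbPath

path = CondOTProbPath()
x_0, x_1 = torch.randn(4), torch.randn(4)   # one scalar per sample
t = torch.full((4,), 0.4)                   # interior time, so all coefficients are finite

sample = path.sample(x_0=x_0, x_1=x_1, t=t)

# velocity -> x_1 and x_1 -> velocity should invert each other
x1_rec = path.velocity_to_target(sample.dx_t, sample.x_t, t)
v_rec = path.target_to_velocity(x_1, sample.x_t, t)
assert torch.allclose(x1_rec, x_1, atol=1e-5)
assert torch.allclose(v_rec, sample.dx_t, atol=1e-5)

# x_1 -> epsilon recovers the source noise (for CondOT, epsilon = x_0)
eps = path.target_to_epsilon(x_1, sample.x_t, t)
assert torch.allclose(eps, x_0, atol=1e-5)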
flow_matching/path/geodesic.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch import Tensor
from torch.func import jvp, vmap

from flow_matching.path.path import ProbPath
from flow_matching.path.path_sample import PathSample
from flow_matching.path.scheduler import ConvexScheduler
from flow_matching.utils import expand_tensor_like
from flow_matching.utils.manifolds import geodesic, Manifold


class GeodesicProbPath(ProbPath):
    r"""The ``GeodesicProbPath`` class represents a specific type of probability path where the transformation between distributions is defined through the geodesic path.
    Mathematically, a geodesic path can be represented as:

    .. math::

        X_t = \psi_t(X_0 | X_1) = \exp_{X_1}(\kappa_t \log_{X_1}(X_0)),

    where :math:`X_t` is the transformed data point at time `t`, :math:`X_0` and :math:`X_1` are the source and target data points, respectively, and :math:`\kappa_t` is a scheduler.

    The scheduler is responsible for providing the time-dependent :math:`\kappa_t` and must be differentiable.

    Using ``GeodesicProbPath`` in the flow matching framework:

    .. code-block:: python

        # Instantiates a manifold
        manifold = FlatTorus()

        # Instantiates a scheduler
        scheduler = CondOTScheduler()

        # Instantiates a probability path
        my_path = GeodesicProbPath(scheduler, manifold)
        mse_loss = torch.nn.MSELoss()

        for x_1 in dataset:
            # Sets x_0 to random noise
            x_0 = torch.randn_like(x_1)

            # Sets t to a random value in [0,1]
            t = torch.rand(x_1.shape[0])

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Computes the MSE loss w.r.t. the velocity
            loss = mse_loss(path_sample.dx_t, my_model(path_sample.x_t, t))
            loss.backward()

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\kappa_t`.
        manifold (Manifold): The manifold on which the probability path is defined.
    """

    def __init__(self, scheduler: ConvexScheduler, manifold: Manifold):
        self.scheduler = scheduler
        self.manifold = manifold

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from the Riemannian probability path with geodesic interpolation:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`\kappa_t`.
        | return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t`, :math:`\dot{X}_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: A conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
        t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()

        def cond_u(x_0, x_1, t):
            path = geodesic(self.manifold, x_0, x_1)
            # Forward-mode JVP gives the point on the geodesic and its time derivative in one pass.
            x_t, dx_t = jvp(
                lambda t: path(self.scheduler(t).alpha_t),
                (t,),
                (torch.ones_like(t).to(t),),
            )
            return x_t, dx_t

        x_t, dx_t = vmap(cond_u)(x_0, x_1, t)
        x_t = x_t.reshape_as(x_1)
        dx_t = dx_t.reshape_as(x_1)

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
flow_matching/path/mixture.py
ADDED
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from torch import Tensor

from flow_matching.path.path import ProbPath
from flow_matching.path.path_sample import DiscretePathSample
from flow_matching.path.scheduler import ConvexScheduler
from flow_matching.utils import expand_tensor_like, unsqueeze_to_match


class MixtureDiscreteProbPath(ProbPath):
    r"""The ``MixtureDiscreteProbPath`` class defines a factorized discrete probability path.

    This path remains constant at the source data point :math:`X_0` until a random time, determined by the scheduler, when it flips to the target data point :math:`X_1`.
    The scheduler determines the flip probability using the parameter :math:`\sigma_t`, which is a function of time `t`. Specifically, :math:`\sigma_t` represents the probability of remaining at :math:`X_0`, while :math:`1 - \sigma_t` is the probability of flipping to :math:`X_1`:

    .. math::

        P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,

    where :math:`\sigma_t` is provided by the scheduler.

    Example:

    .. code-block:: python

        >>> x_0 = torch.zeros((1, 3, 3))
        >>> x_1 = torch.ones((1, 3, 3))

        >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t
        >>> result
        tensor([[[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t
        >>> result
        tensor([[[1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 1.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t
        >>> result
        tensor([[[1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0]]])

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
    """

    def __init__(self, scheduler: ConvexScheduler):
        assert isinstance(
            scheduler, ConvexScheduler
        ), "Scheduler for MixtureDiscreteProbPath must be a ConvexScheduler."

        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        r"""Sample from the mixture discrete probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.
        | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            DiscretePathSample: a conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        sigma_t = self.scheduler(t).sigma_t
        sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

        # With probability sigma_t keep the source token, otherwise flip to the target.
        source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
        x_t = torch.where(condition=source_indices, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

    def posterior_to_velocity(
        self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
    ) -> Tensor:
        r"""Convert the factorized posterior to velocity.

        | given :math:`p(X_1|X_t)`. In the factorized case: :math:`\prod_i p(X_1^i | X_t)`.
        | return :math:`u_t`.

        Args:
            posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
            x_t (Tensor): path sample at time t, shape (...).
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        posterior = torch.softmax(posterior_logits, dim=-1)
        vocabulary_size = posterior.shape[-1]
        x_t = F.one_hot(x_t, num_classes=vocabulary_size)
        t = unsqueeze_to_match(source=t, target=x_t)

        scheduler_output = self.scheduler(t)

        kappa_t = scheduler_output.alpha_t
        d_kappa_t = scheduler_output.d_alpha_t

        return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)
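`posterior_to_velocity` is what a discrete Euler solver consumes: it turns the model's x-prediction posterior into a probability velocity over the vocabulary. A minimal sketch of one Euler update using that velocity; the random logits stand in for a trained model and the vocabulary size is an assumption:

import torch
import torch.nn.functional as F
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler

K = 5                                   # assumed vocabulary size
path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))

x_t = torch.randint(0, K, (2, 4))       # current tokens, shape (batch, d)
t = torch.full((2,), 0.3)
h = 0.05                                # Euler step size

logits = torch.randn(2, 4, K)           # stand-in for model posterior logits p_{1|t}
u_t = path.posterior_to_velocity(logits, x_t, t)    # shape (2, 4, K)

# Euler step on the per-token probabilities, then sample the next state.
p_next = F.one_hot(x_t, num_classes=K).float() + h * u_t
x_next = torch.distributions.Categorical(probs=p_next.clamp(min=0)).sample()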
flow_matching/path/path.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

from torch import Tensor

from flow_matching.path.path_sample import PathSample


class ProbPath(ABC):
    r"""Abstract class, representing a probability path.

    A probability path transforms the distribution :math:`p(X_0)` into :math:`p(X_1)` over :math:`t=0\rightarrow 1`.

    The ``ProbPath`` class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives.
    Here is a high-level example:

    .. code-block:: python

        # Instantiate a probability path
        my_path = ProbPath(...)

        for x_0, x_1 in dataset:
            # Sets t to a random value in [0,1]
            t = torch.rand(x_1.shape[0])

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Optimizes the model. The loss function varies, depending on model and path.
            loss(path_sample, my_model(path_sample.x_t, t)).backward()
    """

    @abstractmethod
    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from an abstract probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)`.
        | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample.
        """

    def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
        assert (
            t.ndim == 1
        ), f"The time vector t must have shape [batch_size]. Got {t.shape}."
        assert (
            t.shape[0] == x_0.shape[0] == x_1.shape[0]
        ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"
flow_matching/path/path_sample.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from torch import Tensor


@dataclass
class PathSample:
    r"""Represents a sample of a conditional-flow generated probability path.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
        dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...).
    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
    dx_t: Tensor = field(
        metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."}
    )


@dataclass
class DiscretePathSample:
    r"""Represents a sample of a conditional-flow generated discrete probability path.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
flow_matching/path/scheduler/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .schedule_transform import ScheduleTransformedModel
from .scheduler import (
    CondOTScheduler,
    ConvexScheduler,
    CosineScheduler,
    LinearVPScheduler,
    PolynomialConvexScheduler,
    Scheduler,
    SchedulerOutput,
    VPScheduler,
)

__all__ = [
    "CondOTScheduler",
    "CosineScheduler",
    "ConvexScheduler",
    "PolynomialConvexScheduler",
    "ScheduleTransformedModel",
    "Scheduler",
    "VPScheduler",
    "LinearVPScheduler",
    "SchedulerOutput",
]
flow_matching/path/scheduler/schedule_transform.py
ADDED
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from torch import Tensor

from flow_matching.path.scheduler.scheduler import Scheduler
from flow_matching.utils import ModelWrapper


class ScheduleTransformedModel(ModelWrapper):
    """
    Change of scheduler for a velocity model.

    This class wraps a given velocity model and transforms its scheduling
    to a new scheduler function. It modifies the time
    dynamics of the model according to the new scheduler while maintaining
    the original model's behavior.

    Example:

    .. code-block:: python

        import torch
        from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
        from flow_matching.solver import ODESolver

        # Initialize the model and schedulers
        model = ...

        original_scheduler = CondOTScheduler()
        new_scheduler = CosineScheduler()

        # Create the transformed model
        transformed_model = ScheduleTransformedModel(
            velocity_model=model,
            original_scheduler=original_scheduler,
            new_scheduler=new_scheduler
        )

        # Set up the solver
        solver = ODESolver(velocity_model=transformed_model)

        x_0 = torch.randn([10, 2])  # Example initial condition

        x_1 = solver.sample(
            time_steps=torch.tensor([0.0, 1.0]),
            x_init=x_0,
            step_size=1/1000
        )[1]

    Args:
        velocity_model (ModelWrapper): The original velocity model to be transformed.
        original_scheduler (Scheduler): The scheduler used by the original model. Must implement the snr_inverse function.
        new_scheduler (Scheduler): The new scheduler to be applied to the model.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        original_scheduler: Scheduler,
        new_scheduler: Scheduler,
    ):
        super().__init__(model=velocity_model)
        self.original_scheduler = original_scheduler
        self.new_scheduler = new_scheduler

        assert hasattr(self.original_scheduler, "snr_inverse") and callable(
            getattr(self.original_scheduler, "snr_inverse")
        ), "The original scheduler must have a callable 'snr_inverse' method."

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Compute the transformed marginal velocity field for a new scheduler.
        This method implements a post-training velocity scheduler change for
        affine conditional flows. It transforms a generating marginal velocity
        field :math:`u_t(x)` based on an original scheduler to a new marginal velocity
        field :math:`\bar{u}_r(x)` based on a different scheduler, while maintaining
        the same data coupling.
        The transformation is based on the scale-time (ST) transformation
        between the two conditional flows, defined as:

        .. math::

            \bar{X}_r = s_r X_{t_r},

        where :math:`X_t` and :math:`\bar{X}_r` are defined by their respective schedulers.
        The ST transformation is computed as:

        .. math::

            t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.

        Here, :math:`\rho(t)` is the signal-to-noise ratio (SNR) defined as:

        .. math::

            \rho(t) = \frac{\alpha_t}{\sigma_t}.

        :math:`\bar{\rho}(r)` is similarly defined for the new scheduler.
        The marginal velocity for the new scheduler is then given by:

        .. math::

            \bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).

        Args:
            x (Tensor): :math:`x_t`, the input tensor.
            t (Tensor): The time tensor (denoted as :math:`r` above).
            **extras: Additional arguments for the model.

        Returns:
            Tensor: The transformed velocity.
        """
        r = t

        r_scheduler_output = self.new_scheduler(t=r)

        alpha_r = r_scheduler_output.alpha_t
        sigma_r = r_scheduler_output.sigma_t
        d_alpha_r = r_scheduler_output.d_alpha_t
        d_sigma_r = r_scheduler_output.d_sigma_t

        t = self.original_scheduler.snr_inverse(alpha_r / sigma_r)

        t_scheduler_output = self.original_scheduler(t=t)

        alpha_t = t_scheduler_output.alpha_t
        sigma_t = t_scheduler_output.sigma_t
        d_alpha_t = t_scheduler_output.d_alpha_t
        d_sigma_t = t_scheduler_output.d_sigma_t

        s_r = sigma_r / sigma_t

        dt_r = (
            sigma_t
            * sigma_t
            * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)
            / (sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t))
        )

        ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t)

        u_t = self.model(x=x / s_r, t=t, **extras)
        u_r = ds_r * x / s_r + dt_r * s_r * u_t

        return u_r
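As a quick smoke test of the scale-time transformation above, the sketch below wraps a trivial constant-velocity field and evaluates it under a new scheduler. `ConstVelocity` is an illustrative stand-in, not part of the upload; a scalar time is used because ODE solvers pass one time value per step:

import torch
from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from flow_matching.utils import ModelWrapper

class ConstVelocity(ModelWrapper):
    # Toy velocity field u_t(x) = 1 for every x and t (illustrative only).
    def __init__(self):
        super().__init__(model=torch.nn.Identity())
    def forward(self, x, t, **extras):
        return torch.ones_like(x)

transformed = ScheduleTransformedModel(
    velocity_model=ConstVelocity(),
    original_scheduler=CondOTScheduler(),
    new_scheduler=CosineScheduler(),
)

x = torch.randn(4, 2)
r = torch.tensor(0.3)       # scalar time, as a solver would supply per step
u = transformed(x, r)       # velocity under the cosine schedule
print(u.shape)              # torch.Size([4, 2])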
flow_matching/path/scheduler/scheduler.py
ADDED
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Union

import torch
from torch import Tensor


@dataclass
class SchedulerOutput:
    r"""Represents the output of a scheduler at time :math:`t`.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).
    """

    alpha_t: Tensor = field(metadata={"help": "alpha_t"})
    sigma_t: Tensor = field(metadata={"help": "sigma_t"})
    d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
    d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})


class Scheduler(ABC):
    """Base Scheduler class."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""
        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise ratio, shape (...).

        Returns:
            Tensor: t, shape (...).
        """
        ...


class ConvexScheduler(Scheduler):
    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Scheduler for convex paths.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...).

        Returns:
            Tensor: t, shape (...).
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise ratio, shape (...).

        Returns:
            Tensor: t, shape (...).
        """
        kappa_t = snr / (1.0 + snr)

        return self.kappa_inverse(kappa=kappa_t)


class CondOTScheduler(ConvexScheduler):
    """CondOT Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-torch.ones_like(t),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return kappa


class PolynomialConvexScheduler(ConvexScheduler):
    """Polynomial Scheduler."""

    def __init__(self, n: Union[float, int]) -> None:
        assert isinstance(
            n, (float, int)
        ), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."

        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t**self.n,
            sigma_t=1 - t**self.n,
            d_alpha_t=self.n * (t ** (self.n - 1)),
            d_sigma_t=-self.n * (t ** (self.n - 1)),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return torch.pow(kappa, 1.0 / self.n)


class VPScheduler(Scheduler):
    """Variance Preserving Scheduler."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        b = self.beta_min
        B = self.beta_max
        T = 0.5 * (1 - t) ** 2 * (B - b) + (1 - t) * b
        dT = -(1 - t) * (B - b) - b

        return SchedulerOutput(
            alpha_t=torch.exp(-0.5 * T),
            sigma_t=torch.sqrt(1 - torch.exp(-T)),
            d_alpha_t=-0.5 * dT * torch.exp(-0.5 * T),
            d_sigma_t=0.5 * dT * torch.exp(-T) / torch.sqrt(1 - torch.exp(-T)),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        T = -torch.log(snr**2 / (snr**2 + 1))
        b = self.beta_min
        B = self.beta_max
        t = 1 - ((-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b))
        return t


class LinearVPScheduler(Scheduler):
    """Linear Variance Preserving Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=(1 - t**2) ** 0.5,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / (1 - t**2) ** 0.5,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        return torch.sqrt(snr**2 / (1 + snr**2))


class CosineScheduler(Scheduler):
    """Cosine Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        pi = torch.pi
        return SchedulerOutput(
            alpha_t=torch.sin(pi / 2 * t),
            sigma_t=torch.cos(pi / 2 * t),
            d_alpha_t=pi / 2 * torch.cos(pi / 2 * t),
            d_sigma_t=-pi / 2 * torch.sin(pi / 2 * t),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        return 2.0 * torch.atan(snr) / torch.pi
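To make the `snr_inverse` contract concrete, a minimal round-trip check (illustrative, not part of the upload) recovering t from the SNR for two of the schedulers above. For CondOT, snr = t/(1-t) so kappa = snr/(1+snr) = t; for the cosine schedule, snr = tan(pi*t/2) so 2*atan(snr)/pi = t:

import torch
from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler

t = torch.linspace(0.1, 0.9, 5)
for scheduler in (CondOTScheduler(), CosineScheduler()):
    out = scheduler(t)
    snr = out.alpha_t / out.sigma_t       # signal-to-noise ratio at t
    t_rec = scheduler.snr_inverse(snr)    # should invert back to t
    assert torch.allclose(t_rec, t, atol=1e-5), scheduler.__class__.__name__
print("snr_inverse round-trips for both schedulers")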
flow_matching/solver/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .discrete_solver import MixtureDiscreteEulerSolver
from .ode_solver import ODESolver
from .riemannian_ode_solver import RiemannianODESolver
from .solver import Solver

__all__ = [
    "ODESolver",
    "Solver",
    "MixtureDiscreteEulerSolver",
    "RiemannianODESolver",
]
flow_matching/solver/discrete_solver.py
ADDED
@@ -0,0 +1,428 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from contextlib import nullcontext
from math import ceil
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import functional as F

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.solver.solver import Solver
from flow_matching.utils import categorical, ModelWrapper
from .utils import get_nearest_times
from ..utils.multi_guidance import *

try:
    from tqdm import tqdm

    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False


class MixtureDiscreteEulerSolver(Solver):
    r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t`, the marginal probability path of ``path``.
    Given :math:`X_t \sim p_t`, the solver step from :math:`t` to :math:`t+h` for the i-th coordinate is:

    .. math::

        \begin{align*}
        & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
        & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
        & Z^i_{\text{change}} \sim U[0,1]\\
        & X_{t+h}^i \sim \begin{cases}
            \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
            \delta_{X_t^i}(\cdot) \text{ else }
        \end{cases}
        \end{align*}

    Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity of the mixture probability path is:

    .. math::

        u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],

    where

    .. math::

        \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],

    and

    .. math::

        \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].

    The source distribution :math:`p(x^i)` is given by ``p``.

    Args:
        model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`); output shape must be [..., vocabulary_size].
        path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
        vocabulary_size (int): size of the discrete vocabulary.
        source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when the divergence-free term for the probability velocity is non-zero. Defaults to None.
    """

    def __init__(
        self,
        model: ModelWrapper,
        path: MixtureDiscreteProbPath,
        vocabulary_size: int,
        source_distribution_p: Optional[Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.path = path
        self.vocabulary_size = vocabulary_size

        if source_distribution_p is not None:
            assert source_distribution_p.shape == torch.Size(
                [vocabulary_size]
            ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}."

        self.source_distribution_p = source_distribution_p

    @torch.no_grad()
    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        **model_extras,
    ) -> Tensor:
        """
        Sample a sequence of discrete values from the given model.

        .. code-block:: python

            import torch
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import MixtureDiscreteEulerSolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)
                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return ...

            model = DummyModel()
            solver = MixtureDiscreteEulerSolver(model=model)

            x_init = torch.LongTensor([122, 725])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): The initial state.
            step_size (Optional[float]): If float then the time discretization is uniform with the given step size. If None then the time discretization is set to time_grid.
            div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time-dependent function. Defaults to 0.0.
            dtype_categorical (torch.dtype): Precision to use for the categorical sampler. Defaults to torch.float32.
            time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then the time discretization is set by the time grid. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool): Whether to print progress bars. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence of discrete values.

        Raises:
            ImportError: To run in verbose mode, tqdm must be installed.
        """
        if div_free != 0.0:
            assert (
                self.source_distribution_p is not None
            ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."

        # Initialize the current state `x_t` with the initial state `X_0`.
        time_grid = time_grid.to(device=x_init.device)
        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # If step_size is a float then the t discretization is uniform with step size set by step_size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            # Get the order of the intermediate steps:
            order = torch.argsort(time_grid)
            # Compute the intermediate steps to return via the nearest points in t_discretization to time_grid.
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # Sample x_1 ~ p_1|t( \cdot |x_t)
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                # Check if this is the final step
                if i == n_steps - 1:
                    x_t = x_1
                else:
                    # Compute u_t(x|x_t,x_1)
                    scheduler_output = self.path.scheduler(t=t)

                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t

                    delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to(
                        k_t.dtype
                    )  # [B, L, V]
                    u = d_k_t / (1 - k_t) * delta_1

                    # Add the divergence-free part
                    div_free_t = div_free(t) if callable(div_free) else div_free

                    if div_free_t > 0:
                        p_0 = self.source_distribution_p[(None,) * x_t.dim()]
                        u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * (
                            (1 - k_t) * p_0 + k_t * delta_1
                        )

                    # Set u_t(x_t|x_t,x_1) = 0
                    delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size)  # [B, L, V]
                    u = torch.where(
                        delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
                    )

                    # Sample x_t ~ u_t( \cdot |x_t,x_1): each coordinate jumps with
                    # probability 1 - exp(-h * lambda), where lambda = sum_x u_t(x|x_t,x_1).
                    intensity = u.sum(dim=-1)  # Assuming u_t(xt|xt,x1) := 0
                    mask_jump = torch.rand(
                        size=x_t.shape, device=x_t.device
                    ) < 1 - torch.exp(-h * intensity)

                    if mask_jump.sum() > 0:
                        x_t[mask_jump] = categorical(
                            u[mask_jump].to(dtype=dtype_categorical)
                        )

                steps_counter += 1
                t = t + h

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t

    @torch.no_grad()
    def multi_guidance_sample(
        self,
        args,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        score_models: list = None,
        num_objectives: int = 1,
        weights: list = None,
        **model_extras,
    ) -> Tensor:
        """Sample as in ``sample``, but steer each Euler step with multiple score
        models combined through a weight vector over ``num_objectives`` objectives."""
        if div_free != 0.0:
            raise NotImplementedError

        # Initialize the current state `x_t` with the initial state `X_0`.
        time_grid = time_grid.to(device=x_init.device)
        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # If step_size is a float then the t discretization is uniform with step size set by step_size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            # Get the order of the intermediate steps:
            order = torch.argsort(time_grid)
            # Compute the intermediate steps to return via the nearest points in t_discretization to time_grid.
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        # Use the provided weight vector, or randomly sample one from the simplex.
        if weights is not None:
            w = torch.tensor(weights).to(device=x_init.device)
        else:
            w, _ = select_random_weight_vector(num_objectives, args.num_div)
            w = w.to(device=x_init.device)
        print(f"Weight Vector: {w}")
        Phi = args.Phi_init
        ema_r_t = None

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                # Guided transition on every step except the final one
                if i != n_steps - 1:
                    # Compute u_t(y,x)
                    scheduler_output = self.path.scheduler(t=t)
                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t
                    u_t = d_k_t / (1 - k_t) * p_1t

                    # Score candidate transitions under the weighted objectives ...
                    guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S = guided_transition_scoring(
                        x_t, u_t, w, score_models, t, w, args
                    )

                    # ... filter them with the adaptive hypercone criterion ...
                    best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = adaptive_hypercone_filtering(
                        improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t
                    )

                    # ... and take the guided Euler step.
                    x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)

                steps_counter += 1
                t = t + h

                # Evaluate the current state under each score model (for logging).
                scores = []
                for s in score_models:
                    sig = inspect.signature(s.forward) if hasattr(s, "forward") else inspect.signature(s)
                    if "t" in sig.parameters:
                        candidate_scores = s(x_t, 1)
                    else:
                        candidate_scores = s(x_t)

                    if isinstance(candidate_scores, tuple):
                        for score in candidate_scores:
                            scores.append(score.item())
                    else:
                        scores.append(candidate_scores.item())
                print(scores)

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t
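To make the per-coordinate jump rule in the docstring above concrete, here is a minimal, self-contained sketch of one Euler step of the CTMC with no divergence-free term. The shapes [B, L] for the state and [B, L, V] for the posterior, and the scalar scheduler values, are illustrative assumptions:

    import torch
    from torch.nn import functional as F

    B, L, V, h = 2, 5, 10, 0.01
    x_t = torch.randint(V, (B, L))                      # current discrete state
    p_1t = torch.softmax(torch.randn(B, L, V), dim=-1)  # stand-in for model(x_t, t)
    k_t, d_k_t = 0.3, 1.0                               # scheduler alpha_t and its derivative

    # Sample x_1 ~ p_1|t(.|x_t), then build u_t(., x_t | x_1).
    x_1 = torch.multinomial(p_1t.view(-1, V), 1).view(B, L)
    u = d_k_t / (1 - k_t) * F.one_hot(x_1, V).float()
    u[F.one_hot(x_t, V).bool()] = 0.0                   # u_t(x_t | x_t, x_1) := 0

    # Each coordinate jumps with probability 1 - exp(-h * lambda),
    # where lambda = sum_x u_t(x, x_t | x_1); otherwise it stays put.
    lam = u.sum(-1)
    jump = torch.rand(B, L) < 1 - torch.exp(-h * lam)
    if jump.any():
        x_t[jump] = torch.multinomial(u[jump], 1).squeeze(-1)

Coordinates where x_1 equals x_t have lambda = 0 and therefore never jump, which matches the delta term in the update rule.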
flow_matching/solver/ode_solver.py
ADDED
@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torchdiffeq import odeint

from flow_matching.solver.solver import Solver
from flow_matching.utils import gradient, ModelWrapper


class ODESolver(Solver):
    """A class to solve ordinary differential equations (ODEs) using a specified velocity model.

    This class utilizes a velocity field model to solve ODEs over a given time grid using numerical ODE solvers.

    Args:
        velocity_model (Union[ModelWrapper, Callable]): a velocity field model receiving :math:`(x,t)` and returning :math:`u_t(x)`
    """

    def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
        super().__init__()
        self.velocity_model = velocity_model

    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tensor, Sequence[Tensor]]:
        r"""Solve the ODE with the velocity field.

        Example:

        .. code-block:: python

            import torch
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import ODESolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)

                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return torch.ones_like(x) * 3.0 * t**2

            velocity_model = DummyModel()
            solver = ODESolver(velocity_model=velocity_model)
            x_init = torch.tensor([0.0, 0.0])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...].
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then the time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tensor, Sequence[Tensor]]: The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
        """

        time_grid = time_grid.to(x_init.device)

        def ode_func(t, x):
            return self.velocity_model(x=x, t=t, **model_extras)

        ode_opts = {"step_size": step_size} if step_size is not None else {}

        with torch.set_grad_enabled(enable_grad):
            # Approximate the ODE solution with a numerical ODE solver
            sol = odeint(
                ode_func,
                x_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        if return_intermediates:
            return sol
        else:
            return sol[-1]

    def compute_likelihood(
        self,
        x_1: Tensor,
        log_p0: Callable[[Tensor], Tensor],
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([1.0, 0.0]),
        return_intermediates: bool = False,
        exact_divergence: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
        r"""Solve for the log likelihood given a target sample at :math:`t=0`.

        Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x.
        The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.

        Args:
            x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`).
            log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): If step_size is None then the time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
            exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of the given x_1.
        """
        assert (
            time_grid[0] == 1.0 and time_grid[-1] == 0.0
        ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"

        # Fix the random projection for the Hutchinson divergence estimator
        if not exact_divergence:
            z = (torch.randn_like(x_1).to(x_1.device) < 0) * 2.0 - 1.0

        def ode_func(x, t):
            return self.velocity_model(x=x, t=t, **model_extras)

        def dynamics_func(t, states):
            xt = states[0]
            with torch.set_grad_enabled(True):
                xt.requires_grad_()
                ut = ode_func(xt, t)

                if exact_divergence:
                    # Compute the exact divergence
                    div = 0
                    for i in range(ut.flatten(1).shape[1]):
                        div += gradient(ut[:, i], xt, create_graph=True)[:, i]
                else:
                    # Compute the Hutchinson divergence estimator E[z^T D_x(ut) z]
                    ut_dot_z = torch.einsum(
                        "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1)
                    )
                    grad_ut_dot_z = gradient(ut_dot_z, xt)
                    div = torch.einsum(
                        "ij,ij->i",
                        grad_ut_dot_z.flatten(start_dim=1),
                        z.flatten(start_dim=1),
                    )

            return ut.detach(), div.detach()

        y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
        ode_opts = {"step_size": step_size} if step_size is not None else {}

        with torch.set_grad_enabled(enable_grad):
            sol, log_det = odeint(
                dynamics_func,
                y_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        x_source = sol[-1]
        source_log_p = log_p0(x_source)

        if return_intermediates:
            return sol, source_log_p + log_det[-1]
        else:
            return sol[-1], source_log_p + log_det[-1]
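The Hutchinson trick used in compute_likelihood replaces the exact divergence (a full Jacobian trace, one backward pass per dimension) with the unbiased estimate E_z[z^T (du/dx) z] for a Rademacher probe z, costing a single vector-Jacobian product. A standalone sketch with a hypothetical vector field f, independent of the library's gradient helper:

    import torch

    def f(x):  # hypothetical vector field R^d -> R^d
        return torch.sin(x) * x

    x = torch.randn(4, 3, requires_grad=True)
    u = f(x)

    # Exact divergence: trace of the Jacobian, one gradient per coordinate.
    div_exact = sum(
        torch.autograd.grad(u[:, i].sum(), x, create_graph=True)[0][:, i]
        for i in range(x.shape[1])
    )

    # Hutchinson estimate with a single Rademacher probe z.
    z = (torch.randn_like(x) < 0) * 2.0 - 1.0
    grad = torch.autograd.grad((u * z).sum(), x)[0]
    div_est = (grad * z).sum(dim=1)  # unbiased; average over many z to reduce variance

Note that compute_likelihood draws z once and keeps it fixed over the whole trajectory, so the integrated log-determinant stays consistent across solver steps.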
flow_matching/solver/riemannian_ode_solver.py
ADDED
@@ -0,0 +1,261 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Callable, Optional

import torch
from torch import Tensor

from flow_matching.solver.solver import Solver
from flow_matching.utils import ModelWrapper
from flow_matching.utils.manifolds import geodesic, Manifold

try:
    from tqdm import tqdm

    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False


class RiemannianODESolver(Solver):
    r"""Riemannian ODE solver.
    Initialize the ``RiemannianODESolver``.

    Args:
        manifold (Manifold): the manifold to solve on.
        velocity_model (ModelWrapper): a velocity field model receiving :math:`(x,t)`
            and returning :math:`u_t(x)` which is assumed to lie on the tangent plane at `x`.
    """

    def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
        super().__init__()
        self.manifold = manifold
        self.velocity_model = velocity_model

    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        projx: bool = True,
        proju: bool = True,
        method: str = "euler",
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Tensor:
        r"""Solve the ODE with the ``velocity_model`` on the manifold.

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`).
            step_size (Optional[float]): The step size.
            projx (bool): Whether to project the point onto the manifold at each step. Defaults to True.
            proju (bool): Whether to project the vector field onto the tangent plane at each step. Defaults to True.
            method (str): One of ["euler", "midpoint", "rk4"]. Defaults to "euler".
            time_grid (Tensor, optional): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then the time discretization is set by the time grid. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool, optional): Whether to print progress bars. Defaults to False.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence. Defaults to returning samples at :math:`t=1`.

        Raises:
            ImportError: To run in verbose mode, tqdm must be installed.
        """
        step_fns = {
            "euler": _euler_step,
            "midpoint": _midpoint_step,
            "rk4": _rk4_step,
        }
        assert method in step_fns.keys(), f"Unknown method {method}"
        step_fn = step_fns[method]

        def velocity_func(x, t):
            return self.velocity_model(x=x, t=t, **model_extras)

        # --- Factor this out.
        time_grid = torch.sort(time_grid.to(device=x_init.device)).values

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # If step_size is a float then the t discretization is uniform with step size set by step_size.
            t_init = time_grid[0].item()
            t_final = time_grid[-1].item()
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [min(time_grid), max(time_grid)] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = math.ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )
        # ---
        t0s = t_discretization[:-1]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            t0s = tqdm(t0s)

        if return_intermediates:
            xts = []
            i_ret = 0

        with torch.set_grad_enabled(enable_grad):
            xt = x_init
            for t0, t1 in zip(t0s, t_discretization[1:]):
                dt = t1 - t0
                xt_next = step_fn(
                    velocity_func,
                    xt,
                    t0,
                    dt,
                    manifold=self.manifold,
                    projx=projx,
                    proju=proju,
                )
                if return_intermediates:
                    while (
                        i_ret < len(time_grid)
                        and t0 <= time_grid[i_ret]
                        and time_grid[i_ret] <= t1
                    ):
                        xts.append(
                            interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
                        )
                        i_ret += 1
                xt = xt_next

        if return_intermediates:
            return torch.stack(xts, dim=0)
        else:
            return xt


def interp(manifold, xt, xt_next, t, t_next, t_ret):
    # Geodesic interpolation between consecutive solver states.
    return geodesic(manifold, xt, xt_next)(
        (t_ret - t) / (t_next - t).reshape(1)
    ).reshape_as(xt)


def _euler_step(
    velocity_model: Callable,
    xt: Tensor,
    t0: Tensor,
    dt: Tensor,
    manifold: Manifold,
    projx: bool = True,
    proju: bool = True,
) -> Tensor:
    r"""Perform an Euler step on a manifold.

    Args:
        velocity_model (Callable): the velocity model
        xt (Tensor): tensor containing the state at time t0
        t0 (Tensor): the time at which this step is taken
        dt (Tensor): the step size
        manifold (Manifold): a manifold object
        projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
        proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.

    Returns:
        Tensor: tensor containing the state after the step
    """
    velocity_fn = lambda x, t: (
        manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
    )
    projx_fn = lambda x: manifold.projx(x) if projx else x

    vt = velocity_fn(xt, t0)

    xt = xt + dt * vt

    return projx_fn(xt)


def _midpoint_step(
    velocity_model: Callable,
    xt: Tensor,
    t0: Tensor,
    dt: Tensor,
    manifold: Manifold,
    projx: bool = True,
    proju: bool = True,
) -> Tensor:
    r"""Perform a midpoint step on a manifold.

    Args:
        velocity_model (Callable): the velocity model
        xt (Tensor): tensor containing the state at time t0
        t0 (Tensor): the time at which this step is taken
        dt (Tensor): the step size
        manifold (Manifold): a manifold object
        projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
        proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.

    Returns:
        Tensor: tensor containing the state after the step
    """
    velocity_fn = lambda x, t: (
        manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
    )
    projx_fn = lambda x: manifold.projx(x) if projx else x

    half_dt = 0.5 * dt
    vt = velocity_fn(xt, t0)
    x_mid = xt + half_dt * vt
    x_mid = projx_fn(x_mid)

    xt = xt + dt * velocity_fn(x_mid, t0 + half_dt)

    return projx_fn(xt)


def _rk4_step(
    velocity_model: Callable,
    xt: Tensor,
    t0: Tensor,
    dt: Tensor,
    manifold: Manifold,
    projx: bool = True,
    proju: bool = True,
) -> Tensor:
    r"""Perform an RK4 step on a manifold.

    Args:
        velocity_model (Callable): the velocity model
        xt (Tensor): tensor containing the state at time t0
        t0 (Tensor): the time at which this step is taken
        dt (Tensor): the step size
        manifold (Manifold): a manifold object
        projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
        proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.

    Returns:
        Tensor: tensor containing the state after the step
    """
    velocity_fn = lambda x, t: (
        manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
    )
    projx_fn = lambda x: manifold.projx(x) if projx else x

    k1 = velocity_fn(xt, t0)
    k2 = velocity_fn(projx_fn(xt + dt * k1 / 3), t0 + dt / 3)
    k3 = velocity_fn(projx_fn(xt + dt * (k2 - k1 / 3)), t0 + dt * 2 / 3)
    k4 = velocity_fn(projx_fn(xt + dt * (k1 - k2 + k3)), t0 + dt)

    return projx_fn(xt + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125)
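A minimal end-to-end sketch of the solver above on the unit sphere. The constant velocity field here is hypothetical; proju=True projects it onto the tangent plane at each step, and projx=True keeps the state on the manifold:

    import torch
    from flow_matching.solver import RiemannianODESolver
    from flow_matching.utils import ModelWrapper
    from flow_matching.utils.manifolds import Sphere

    class TangentDrift(ModelWrapper):  # hypothetical velocity field
        def __init__(self):
            super().__init__(None)

        def forward(self, x, t, **extras):
            return torch.ones_like(x)  # projected onto T_x S^2 by proju=True

    sphere = Sphere()
    solver = RiemannianODESolver(manifold=sphere, velocity_model=TangentDrift())
    x_init = sphere.projx(torch.randn(8, 3))
    x_final = solver.sample(x_init=x_init, step_size=0.01, method="midpoint")
    # Each state stays (numerically) on the manifold thanks to projx at every step.
    assert torch.allclose(x_final.norm(dim=-1), torch.ones(8), atol=1e-4)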
flow_matching/solver/solver.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Optional

from torch import nn, Tensor


class Solver(ABC, nn.Module):
    """Abstract base class for solvers."""

    @abstractmethod
    def sample(self, x_0: Optional[Tensor] = None) -> Tensor:
        ...
flow_matching/solver/utils.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor


def get_nearest_times(time_grid: Tensor, t_discretization: Tensor) -> Tensor:
    """Snap each time in ``time_grid`` to its nearest neighbor in ``t_discretization``."""
    distances = torch.cdist(
        time_grid.unsqueeze(1),
        t_discretization.unsqueeze(1),
        compute_mode="donot_use_mm_for_euclid_dist",
    )
    nearest_indices = distances.argmin(dim=1)

    return t_discretization[nearest_indices]
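This is how return_intermediates in the solvers decides which states to record: each requested time is snapped to the closest point of the actual discretization. For example, this should print tensor([0.0000, 0.5000, 1.0000]):

    import torch
    from flow_matching.solver.utils import get_nearest_times

    t_disc = torch.arange(0.0, 1.01, 0.25)   # [0.00, 0.25, 0.50, 0.75, 1.00]
    grid = torch.tensor([0.1, 0.6, 1.0])
    print(get_nearest_times(time_grid=grid, t_discretization=t_disc))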
flow_matching/utils/__init__.py
ADDED
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .categorical_sampler import categorical
from .model_wrapper import ModelWrapper
from .utils import expand_tensor_like, gradient, unsqueeze_to_match

__all__ = [
    "unsqueeze_to_match",
    "expand_tensor_like",
    "gradient",
    "categorical",
    "ModelWrapper",
]
flow_matching/utils/categorical_sampler.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor


def categorical(probs: Tensor) -> Tensor:
    r"""Categorical sampler according to weights in the last dimension of ``probs`` using :func:`torch.multinomial`.

    Args:
        probs (Tensor): probabilities.

    Returns:
        Tensor: Samples.
    """

    return torch.multinomial(probs.flatten(0, -2), 1, replacement=True).view(
        *probs.shape[:-1]
    )
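Because the sampler flattens all leading dimensions, draws one sample per row, and reshapes back, it works for any [..., V] tensor of (unnormalized, non-negative) weights:

    import torch
    from flow_matching.utils import categorical

    probs = torch.softmax(torch.randn(4, 16, 100), dim=-1)  # [B, L, V]
    samples = categorical(probs)
    assert samples.shape == (4, 16) and samples.max() < 100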
flow_matching/utils/manifolds/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .manifold import Euclidean, Manifold
from .sphere import Sphere
from .torus import FlatTorus
from .utils import geodesic

__all__ = [
    "Euclidean",
    "Manifold",
    "Sphere",
    "FlatTorus",
    "geodesic",
]
flow_matching/utils/manifolds/manifold.py
ADDED
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import abc

import torch.nn as nn
from torch import Tensor


class Manifold(nn.Module, metaclass=abc.ABCMeta):
    """A manifold class that contains projection operations and logarithm and exponential maps."""

    @abc.abstractmethod
    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Computes the exponential map :math:`\exp_x(u)`.

        Args:
            x (Tensor): point on the manifold
            u (Tensor): tangent vector at point :math:`x`

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: transported point
        """
        raise NotImplementedError

    @abc.abstractmethod
    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Computes the logarithmic map :math:`\log_x(y)`.

        Args:
            x (Tensor): point on the manifold
            y (Tensor): point on the manifold

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: tangent vector at point :math:`x`
        """
        raise NotImplementedError

    @abc.abstractmethod
    def projx(self, x: Tensor) -> Tensor:
        """Project point :math:`x` onto the manifold.

        Args:
            x (Tensor): point to be projected

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: projected point on the manifold
        """
        raise NotImplementedError

    @abc.abstractmethod
    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Project vector :math:`u` onto the tangent space at :math:`x`.

        Args:
            x (Tensor): point on the manifold
            u (Tensor): vector to be projected

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: projected tangent vector
        """
        raise NotImplementedError


class Euclidean(Manifold):
    """The Euclidean manifold."""

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        return x + u

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        return y - x

    def projx(self, x: Tensor) -> Tensor:
        return x

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        return u
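The contract shared by all subclasses: expmap and logmap are mutual inverses along geodesics, and projx/proju are identities for points and vectors already on the manifold. On Euclidean this reduces to plain addition and subtraction:

    import torch
    from flow_matching.utils.manifolds import Euclidean

    m = Euclidean()
    x = torch.randn(5, 3)
    y = torch.randn(5, 3)
    # logmap gives the tangent vector that expmap transports along:
    # expmap(x, logmap(x, y)) == y.
    assert torch.allclose(m.expmap(x, m.logmap(x, y)), y)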
flow_matching/utils/manifolds/sphere.py
ADDED
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor

from flow_matching.utils.manifolds import Manifold


class Sphere(Manifold):
    """Represents a hypersphere in :math:`R^D`."""

    EPS = {torch.float32: 1e-4, torch.float64: 1e-7}

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        norm_u = u.norm(dim=-1, keepdim=True)
        exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u
        retr = self.projx(x + u)
        cond = norm_u > self.EPS[norm_u.dtype]

        return torch.where(cond, exp, retr)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        u = self.proju(x, y - x)
        dist = self.dist(x, y, keepdim=True)
        cond = dist.gt(self.EPS[x.dtype])
        result = torch.where(
            cond,
            u * dist / u.norm(dim=-1, keepdim=True).clamp_min(self.EPS[x.dtype]),
            u,
        )
        return result

    def projx(self, x: Tensor) -> Tensor:
        return x / x.norm(dim=-1, keepdim=True)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        return u - (x * u).sum(dim=-1, keepdim=True) * x

    def dist(self, x: Tensor, y: Tensor, *, keepdim=False) -> Tensor:
        inner = (x * y).sum(-1, keepdim=keepdim)
        # Clamp to the valid domain of acos to avoid NaNs from rounding error.
        inner = inner.clamp(-1 + self.EPS[x.dtype], 1 - self.EPS[x.dtype])
        return torch.acos(inner)
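A quick numerical check of the sphere operations: for generic (non-antipodal) points, logmap should produce a tangent vector at x whose expmap lands back on y, up to float tolerance. A minimal sketch in float64:

    import torch
    from flow_matching.utils.manifolds import Sphere

    s = Sphere()
    x = s.projx(torch.randn(10, 3, dtype=torch.float64))
    y = s.projx(torch.randn(10, 3, dtype=torch.float64))
    u = s.logmap(x, y)  # tangent at x: (x * u).sum(-1) ~ 0
    assert torch.allclose((x * u).sum(-1), torch.zeros(10, dtype=torch.float64), atol=1e-8)
    assert torch.allclose(s.expmap(x, u), y, atol=1e-5)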
flow_matching/utils/manifolds/torus.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch import Tensor

from flow_matching.utils.manifolds import Manifold


class FlatTorus(Manifold):
    r"""Represents a flat torus on the :math:`[0, 2\pi]^D` subspace. Isometric to the product of 1-D spheres."""

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        return (x + u) % (2 * math.pi)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        return torch.atan2(torch.sin(y - x), torch.cos(y - x))

    def projx(self, x: Tensor) -> Tensor:
        return x % (2 * math.pi)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        return u
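On the flat torus, logmap always returns the shortest angular displacement in (-pi, pi], and expmap wraps the result back into [0, 2pi). For instance, the shortest path from 0.1 to 2pi - 0.1 goes backwards through 0:

    import math
    import torch
    from flow_matching.utils.manifolds import FlatTorus

    m = FlatTorus()
    x = torch.tensor([0.1])
    y = torch.tensor([2 * math.pi - 0.1])
    u = m.logmap(x, y)  # tensor([-0.2]): go the short way round
    assert torch.allclose(m.expmap(x, u), y, atol=1e-6)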
flow_matching/utils/manifolds/utils.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch
from torch import Tensor

from flow_matching.utils.manifolds import Manifold


def geodesic(
    manifold: Manifold, start_point: Tensor, end_point: Tensor
) -> Callable[[Tensor], Tensor]:
    """Generate a parameterized function for the geodesic curve.

    Args:
        manifold (Manifold): the manifold to compute the geodesic on.
        start_point (Tensor): point on the manifold at :math:`t=0`.
        end_point (Tensor): point on the manifold at :math:`t=1`.

    Returns:
        Callable[[Tensor], Tensor]: a function that takes in :math:`t` and outputs the geodesic at time :math:`t`.
    """

    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t: Tensor) -> Tensor:
        """Evaluate the geodesic curve at the given times.

        Args:
            t (Tensor): Times at which to compute points of the geodesics.

        Returns:
            Tensor: geodesic path evaluated at time t.
        """
        tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
        points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)

        return points_at_time_t

    return path
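geodesic shoots the tangent vector logmap(start, end) from start_point and rescales it by t, so t=0 gives the start, t=1 the end, and intermediate values interpolate along the manifold. For instance, on the sphere:

    import torch
    from flow_matching.utils.manifolds import Sphere, geodesic

    s = Sphere()
    a = s.projx(torch.randn(3))
    b = s.projx(torch.randn(3))
    path = geodesic(s, a, b)
    pts = path(torch.tensor([0.0, 0.5, 1.0]))  # shape [3, 3]: points at t = 0, 0.5, 1
    assert torch.allclose(pts[0], a, atol=1e-5) and torch.allclose(pts[-1], b, atol=1e-5)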
flow_matching/utils/model_wrapper.py
ADDED
@@ -0,0 +1,43 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the CC-by-NC license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from abc import ABC
|
| 8 |
+
|
| 9 |
+
from torch import nn, Tensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ModelWrapper(ABC, nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
This class is used to wrap around another model, adding custom forward pass logic.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, model: nn.Module):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.model = model
|
| 20 |
+
|
| 21 |
+
def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
|
| 22 |
+
r"""
|
| 23 |
+
This method defines how inputs should be passed through the wrapped model.
|
| 24 |
+
Here, we're assuming that the wrapped model takes both :math:`x` and :math:`t` as input,
|
| 25 |
+
along with any additional keyword arguments.
|
| 26 |
+
|
| 27 |
+
Optional things to do here:
|
| 28 |
+
- check that t is in the dimensions that the model is expecting.
|
| 29 |
+
- add a custom forward pass logic.
|
| 30 |
+
- call the wrapped model.
|
| 31 |
+
|
| 32 |
+
| given x, t
|
| 33 |
+
| returns the model output for input x at time t, with extra information `extra`.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
x (Tensor): input data to the model (batch_size, ...).
|
| 37 |
+
t (Tensor): time (batch_size).
|
| 38 |
+
**extras: additional information forwarded to the model, e.g., text condition.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Tensor: model output.
|
| 42 |
+
"""
|
| 43 |
+
return self.model(x=x, t=t, **extras)
|
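
The solvers in this library invoke the wrapped model with keyword arguments, `model(x=..., t=..., **extras)`, so a subclass only needs to override `forward` when the underlying network uses a different calling convention. A minimal hypothetical sketch (`TimeFirstNet` is an illustrative toy, not part of the upload):

import torch
from torch import nn, Tensor
from flow_matching.utils import ModelWrapper

class TimeFirstNet(nn.Module):
    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        return x * t.unsqueeze(-1)  # toy dynamics for illustration

class PositionalArgsWrapper(ModelWrapper):
    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        # Reorder the keyword inputs into the wrapped net's positional convention.
        return self.model(t, x)

out = PositionalArgsWrapper(TimeFirstNet())(x=torch.rand(4, 8), t=torch.rand(4))  # shape (4, 8)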

flow_matching/utils/multi_guidance.py
ADDED
@@ -0,0 +1,216 @@
import torch
from flow_matching.utils import categorical
import math
import inspect

def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
    def rec(n, H):
        if n == 1:
            return [[H]]
        points = []
        for i in range(H + 1):
            for tail in rec(n - 1, H - i):
                points.append([i] + tail)
        return points

    points = rec(num_obj, num_div)
    weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div
    return weight_vectors

def select_random_weight_vector(num_obj: int, num_div: int):
    weight_vectors = generate_simplex_lattice_points(num_obj, num_div)
    idx = torch.randint(0, weight_vectors.size(0), (1,)).item()
    random_weight_vector = weight_vectors[idx]
    return random_weight_vector, weight_vectors

def z_score_norm(tensor, eps=1e-8):
    mean = tensor.mean(dim=-1, keepdim=True)
    std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
    return (tensor - mean) / std

def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
    B, L, vocab_size = u_t.shape
    device = x_t.device
    guided_u_t = u_t.clone()

    # 1. Randomly select one position per sequence.
    pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device)  # shape: (B,)  # CHANGE!
    batch_idx = torch.arange(B, device=device)
    current_tokens = x_t[batch_idx, pos_indices]  # shape: (B,)

    # 2. Build candidate tokens for each sequence and remove the self-transition.
    full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)  # (B, vocab_size)
    mask = (full_cand_tokens != current_tokens.unsqueeze(1))  # (B, vocab_size)
    # Now, cand_tokens contains only candidate tokens that differ from the current token.
    cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 1)  # (B, vocab_size-1)

    # 3. Create candidate sequences by replacing the token at the selected position.
    new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
    new_x = new_x[mask].view(B, vocab_size - 1, L)  # (B, vocab_size-1, L)
    new_x[batch_idx, :, pos_indices] = cand_tokens

    new_x_flat = new_x.view(B * (vocab_size - 1), L)
    improvements_list = []
    with torch.no_grad():
        count = 0
        for i, s in enumerate(s_models):
            sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
            if 't' in sig.parameters:
                candidate_scores = s(new_x_flat, t)
                base_score = s(x_t, t)
            else:
                candidate_scores = s(new_x_flat)
                base_score = s(x_t)

            if isinstance(candidate_scores, tuple):
                for k, score in enumerate(candidate_scores):
                    improvement = score.view(B, vocab_size - 1) - base_score[k].unsqueeze(1)
                    improvement = improvement.float()
                    improvement *= importance[count]
                    improvements_list.append(improvement.unsqueeze(2))
                    count += 1
            else:
                improvement = candidate_scores.view(B, vocab_size - 1) - base_score.unsqueeze(1)
                improvement = improvement.float()
                improvement *= importance[count]
                improvements_list.append(improvement.unsqueeze(2))  # (B, vocab_size-1, 1)
                count += 1

    improvement_values = torch.cat(improvements_list, dim=2)  # (B, vocab_size-1, N)
    if args.is_peptide:
        improvement_values[:, :4, :] = -10  # Mask candidates that correspond to special (non-residue) tokens.

    # 4. Compute ranking scores I_n.
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1  # (B, vocab_size-1, N)
    I_n = ranks / float(vocab_size - 1)
    avg_I = I_n.mean(dim=2)
    norm_avg_I = z_score_norm(avg_I)  # (B, vocab_size-1)

    # 5. Compute directional score D.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)  # (B, vocab_size-1)

    # 6. Combine the scores.
    delta_S = norm_avg_I + args.lambda_ * norm_D  # (B, vocab_size-1)

    # 7. Update the guided velocities at the selected positions.
    factor = torch.exp(args.beta * delta_S)  # (B, vocab_size-1)
    factor = torch.clamp(factor, min=-100, max=100)

    guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor

    # 8. For the self-transition (current token) at the selected position,
    # set its guided velocity to be the negative sum of the updated off-diagonals.
    updated_vals = guided_u_t[batch_idx, pos_indices, :]  # (B, vocab_size)
    sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag

    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S

def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    eps = 1e-8

    # Compute norms and angles.
    imp_norm = torch.norm(improvement_values.float(), dim=2)  # (B, num_candidates)
    dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    w_norm = torch.norm(w) + eps
    cos_angle = dot_product / (imp_norm * w_norm + eps)
    cos_angle = cos_angle.clamp(-1.0, 1.0)
    angles = torch.acos(cos_angle)  # (B, num_candidates)

    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)  # (B, num_candidates)

    # Determine the best candidate for each sequence.
    # We'll use a loop over batch items (batch size is typically moderate).
    best_candidate = torch.empty(B, dtype=torch.long, device=device)
    for i in range(B):
        # For sequence i, consider only valid candidates.
        if valid_mask[i].any():
            # There is at least one candidate with α^i < π/2.
            if accepted_mask[i].any():
                # At least one candidate passes the hypercone: choose the one with max delta_S among accepted.
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~accepted_mask[i], float('-inf')))
            else:
                # No candidate was accepted, but some are valid. Select the best candidate among valid ones.
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~valid_mask[i], float('-inf')))
            best_candidate[i] = cand_tokens[i, candidate_idx]
        else:
            # No candidate is valid (all α^i >= π/2) → self-transition.
            best_candidate[i] = -1

    # Compute the rejection rate only over valid candidates.
    rejection_rates = []
    for i in range(B):
        valid_candidates = valid_mask[i]
        total_valid = valid_candidates.sum().item()
        if total_valid > 0:
            # Among valid candidates, count how many are rejected.
            num_rejected = (valid_candidates.sum() - accepted_mask[i].sum()).item()
            rejection_rates.append(num_rejected / total_valid)
    if len(rejection_rates) > 0:
        r_t = sum(rejection_rates) / len(rejection_rates)
    else:
        # If no sequence has any valid candidate, set r_t to 0.
        r_t = 0.0

    if ema_r_t is None:
        ema_r_t = args.tau

    # Update the hypercone angle and EMA rejection rate only if there is at least one valid candidate in the batch.
    if valid_mask.any():
        new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
        new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
        new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()
    else:
        new_ema_r_t = ema_r_t
        new_Phi = Phi  # No update if no valid candidate exists.

    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t

def get_best_candidate(improvement_values, cand_tokens, delta_S):
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    best_candidate = torch.empty(B, dtype=torch.long, device=device)

    for i in range(B):
        candidate_idx = torch.argmax(delta_S[i])
        best_candidate[i] = cand_tokens[i, candidate_idx]

    return best_candidate

def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    B, L, V = guided_u_t.shape
    device = x_t.device
    u = torch.zeros_like(guided_u_t)

    valid_mask = best_candidate != -1
    if valid_mask.any():
        valid_idx = torch.nonzero(valid_mask).squeeze(-1)
        # For these sequences, update the velocity at the selected position and candidate token.
        u[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] = \
            guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]]

    # Compute intensity at the selected positions.
    # For sequences with no valid candidate (i.e. self-transition), intensity remains zero.
    intensity = torch.zeros(B, device=device)
    if valid_mask.any():
        intensity[valid_idx] = u[valid_idx, pos_indices[valid_idx]].sum(dim=-1)

    # According to the Euler sampling formula, `p_jump` should be `1 - torch.exp(-h * intensity)`.
    # However, since `h = 1 / T` is small, p_jump becomes tiny and slows down sampling.
    # To compensate, we scale `intensity` by T; this is equivalent to setting `args.beta` to `T * args.beta`.
    # So for faster sampling, we just use `1 - torch.exp(-1 * intensity)`.
    p_jump = 1 - torch.exp(-1 * intensity)

    rand_val = torch.rand(B, device=device)

    jump_decision = (rand_val < p_jump) & valid_mask
    if jump_decision.any():
        print("Jump!")
    # For sequences where a jump is decided, update the token at pos_indices to best_candidate.
    x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]

    return x_t
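
For reference, `generate_simplex_lattice_points` enumerates every weight vector with non-negative entries summing to one on a lattice of granularity 1/num_div, so a small case can be checked by hand (illustrative, not part of the upload):

import torch
from flow_matching.utils.multi_guidance import generate_simplex_lattice_points

w = generate_simplex_lattice_points(num_obj=2, num_div=4)
# tensor([[0.0000, 1.0000],
#         [0.2500, 0.7500],
#         [0.5000, 0.5000],
#         [0.7500, 0.2500],
#         [1.0000, 0.0000]])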

flow_matching/utils/multi_guidance_cnp.py
ADDED
@@ -0,0 +1,217 @@
import torch
from flow_matching.utils import categorical
import math
import inspect
import random

def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
    def rec(n, H):
        if n == 1:
            return [[H]]
        points = []
        for i in range(H + 1):
            for tail in rec(n - 1, H - i):
                points.append([i] + tail)
        return points

    points = rec(num_obj, num_div)
    weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div
    return weight_vectors

def select_random_weight_vector(num_obj: int, num_div: int):
    weight_vectors = generate_simplex_lattice_points(num_obj, num_div)
    idx = torch.randint(0, weight_vectors.size(0), (1,)).item()
    random_weight_vector = weight_vectors[idx]
    return random_weight_vector, weight_vectors

def z_score_norm(tensor, eps=1e-8):
    mean = tensor.mean(dim=-1, keepdim=True)
    std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
    return (tensor - mean) / std

def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
    B, L, vocab_size = u_t.shape
    device = x_t.device
    guided_u_t = u_t.clone()

    # 1. Select one position, shared across the batch, skipping index 6.
    # pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device)  # shape: (B,)  # CHANGE!
    pos_indices = torch.tensor([random.choice([i for i in range(1, L-2) if i != 6])]).to(x_t.device)
    batch_idx = torch.arange(B, device=device)
    current_tokens = x_t[batch_idx, pos_indices]  # shape: (B,)

    # 2. Build candidate tokens for each sequence; remove the self-transition and token 23.
    full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)  # (B, vocab_size)
    mask = (full_cand_tokens != current_tokens.unsqueeze(1)) & (full_cand_tokens != 23)  # (B, vocab_size)
    # Now, cand_tokens contains only candidate tokens that differ from the current token.
    cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 2)  # (B, vocab_size-2)

    # 3. Create candidate sequences by replacing the token at the selected position.
    new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
    new_x = new_x[mask].view(B, vocab_size - 2, L)  # (B, vocab_size-2, L)
    new_x[batch_idx, :, pos_indices] = cand_tokens

    new_x_flat = new_x.view(B * (vocab_size - 2), L)
    improvements_list = []
    with torch.no_grad():
        count = 0
        for i, s in enumerate(s_models):
            sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
            if 't' in sig.parameters:
                candidate_scores = s(new_x_flat, t)
                base_score = s(x_t, t)
            else:
                candidate_scores = s(new_x_flat)
                base_score = s(x_t)

            if isinstance(candidate_scores, tuple):
                for k, score in enumerate(candidate_scores):
                    improvement = score.view(B, vocab_size - 2) - base_score[k].unsqueeze(1)
                    improvement = improvement.float()
                    improvement *= importance[count]
                    improvements_list.append(improvement.unsqueeze(2))
                    count += 1
            else:
                improvement = candidate_scores.view(B, vocab_size - 2) - base_score.unsqueeze(1)
                improvement = improvement.float()
                improvement *= importance[count]
                improvements_list.append(improvement.unsqueeze(2))  # (B, vocab_size-2, 1)
                count += 1

    improvement_values = torch.cat(improvements_list, dim=2)  # (B, vocab_size-2, N)
    if args.is_peptide:
        improvement_values[:, :4, :] = -10  # Mask candidates that correspond to special (non-residue) tokens.

    # 4. Compute ranking scores I_n.
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1  # (B, vocab_size-2, N)
    I_n = ranks / float(vocab_size - 2)
    avg_I = I_n.mean(dim=2)
    norm_avg_I = z_score_norm(avg_I)  # (B, vocab_size-2)

    # 5. Compute directional score D.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)  # (B, vocab_size-2)

    # 6. Combine the scores.
    delta_S = norm_avg_I + args.lambda_ * norm_D  # (B, vocab_size-2)

    # 7. Update the guided velocities at the selected positions.
    factor = torch.exp(args.beta * delta_S)  # (B, vocab_size-2)
    factor = torch.clamp(factor, min=-100, max=100)

    guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor

    # 8. For the self-transition (current token) at the selected position,
    # set its guided velocity to be the negative sum of the updated off-diagonals.
    updated_vals = guided_u_t[batch_idx, pos_indices, :]  # (B, vocab_size)
    sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag

    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S

def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    eps = 1e-8

    # Compute norms and angles.
    imp_norm = torch.norm(improvement_values.float(), dim=2)  # (B, num_candidates)
    dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    w_norm = torch.norm(w) + eps
    cos_angle = dot_product / (imp_norm * w_norm + eps)
    cos_angle = cos_angle.clamp(-1.0, 1.0)
    angles = torch.acos(cos_angle)  # (B, num_candidates)

    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)  # (B, num_candidates)

    # Determine the best candidate for each sequence.
    # We'll use a loop over batch items (batch size is typically moderate).
    best_candidate = torch.empty(B, dtype=torch.long, device=device)
    for i in range(B):
        # For sequence i, consider only valid candidates.
        if valid_mask[i].any():
            # There is at least one candidate with α^i < π/2.
            if accepted_mask[i].any():
                # At least one candidate passes the hypercone: choose the one with max delta_S among accepted.
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~accepted_mask[i], float('-inf')))
            else:
                # No candidate was accepted, but some are valid. Select the best candidate among valid ones.
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~valid_mask[i], float('-inf')))
            best_candidate[i] = cand_tokens[i, candidate_idx]
        else:
            # No candidate is valid (all α^i >= π/2) → self-transition.
            best_candidate[i] = -1

    # Compute the rejection rate only over valid candidates.
    rejection_rates = []
    for i in range(B):
        valid_candidates = valid_mask[i]
        total_valid = valid_candidates.sum().item()
        if total_valid > 0:
            # Among valid candidates, count how many are rejected.
            num_rejected = (valid_candidates.sum() - accepted_mask[i].sum()).item()
            rejection_rates.append(num_rejected / total_valid)
    if len(rejection_rates) > 0:
        r_t = sum(rejection_rates) / len(rejection_rates)
    else:
        # If no sequence has any valid candidate, set r_t to 0.
        r_t = 0.0

    if ema_r_t is None:
        ema_r_t = args.tau

    # Update the hypercone angle and EMA rejection rate only if there is at least one valid candidate in the batch.
    if valid_mask.any():
        new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
        new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
        new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()
    else:
        new_ema_r_t = ema_r_t
        new_Phi = Phi  # No update if no valid candidate exists.

    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t

def get_best_candidate(improvement_values, cand_tokens, delta_S):
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    best_candidate = torch.empty(B, dtype=torch.long, device=device)

    for i in range(B):
        candidate_idx = torch.argmax(delta_S[i])
        best_candidate[i] = cand_tokens[i, candidate_idx]

    return best_candidate

def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    B, L, V = guided_u_t.shape
    device = x_t.device
    u = torch.zeros_like(guided_u_t)

    valid_mask = best_candidate != -1
    if valid_mask.any():
        valid_idx = torch.nonzero(valid_mask).squeeze(-1)
        # For these sequences, update the velocity at the selected position and candidate token.
        u[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] = \
            guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]]

    # Compute intensity at the selected positions.
    # For sequences with no valid candidate (i.e. self-transition), intensity remains zero.
    intensity = torch.zeros(B, device=device)
    if valid_mask.any():
        intensity[valid_idx] = u[valid_idx, pos_indices[valid_idx]].sum(dim=-1)

    # According to the Euler sampling formula, `p_jump` should be `1 - torch.exp(-h * intensity)`.
    # However, since `h = 1 / T` is small, p_jump becomes tiny and slows down sampling.
    # To compensate, we scale `intensity` by T; this is equivalent to setting `args.beta` to `T * args.beta`.
    # So for faster sampling, we just use `1 - torch.exp(-1 * intensity)`.
    p_jump = 1 - torch.exp(-1 * intensity)

    rand_val = torch.rand(B, device=device)

    jump_decision = (rand_val < p_jump) & valid_mask

    # For sequences where a jump is decided, update the token at pos_indices to best_candidate.
    x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]

    return x_t
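
The jump rule shared by both `euler_sample` variants maps the guided intensity at the selected position to a Bernoulli probability. A few illustrative values (not part of the upload):

import torch

intensity = torch.tensor([0.0, 0.5, 2.0])
p_jump = 1 - torch.exp(-1 * intensity)
# tensor([0.0000, 0.3935, 0.8647]): zero intensity never jumps; large intensity almost always does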

flow_matching/utils/utils.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch import Tensor


def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
    """
    Unsqueeze the source tensor to match the dimensionality of the target tensor.

    Args:
        source (Tensor): The source tensor to be unsqueezed.
        target (Tensor): The target tensor to match the dimensionality of.
        how (str, optional): Whether to unsqueeze the source tensor at the beginning
            ("prefix") or end ("suffix"). Defaults to "suffix".

    Returns:
        Tensor: The unsqueezed source tensor.
    """
    assert (
        how == "prefix" or how == "suffix"
    ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."

    dim_diff = target.dim() - source.dim()

    for _ in range(dim_diff):
        if how == "prefix":
            source = source.unsqueeze(0)
        elif how == "suffix":
            source = source.unsqueeze(-1)

    return source


def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
    """`input_tensor` is a 1d vector of length equal to the batch size of `expand_to`;
    expand `input_tensor` to have the same shape as `expand_to` along all remaining dimensions.

    Args:
        input_tensor (Tensor): (batch_size,).
        expand_to (Tensor): (batch_size, ...).

    Returns:
        Tensor: (batch_size, ...).
    """
    assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
    assert (
        input_tensor.shape[0] == expand_to.shape[0]
    ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."

    dim_diff = expand_to.ndim - input_tensor.ndim

    t_expanded = input_tensor.clone()
    t_expanded = t_expanded.reshape(-1, *([1] * dim_diff))

    return t_expanded.expand_as(expand_to)


def gradient(
    output: Tensor,
    x: Tensor,
    grad_outputs: Optional[Tensor] = None,
    create_graph: bool = False,
) -> Tensor:
    """
    Compute the gradient of the inner product of output and grad_outputs w.r.t :math:`x`.

    Args:
        output (Tensor): [N, D] Output of the function.
        x (Tensor): [N, d_1, d_2, ... ] input
        grad_outputs (Optional[Tensor]): [N, D] Gradient of outputs, if `None`,
            then will use a tensor of ones
        create_graph (bool): If True, graph of the derivative will be constructed, allowing
            to compute higher order derivative products. Defaults to False.

    Returns:
        Tensor: [N, d_1, d_2, ... ]. the gradient w.r.t x.
    """

    if grad_outputs is None:
        grad_outputs = torch.ones_like(output).detach()
    grad = torch.autograd.grad(
        output, x, grad_outputs=grad_outputs, create_graph=create_graph
    )[0]
    return grad
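
A short sketch of the two broadcasting helpers (not part of the upload; the import path is assumed from this file's location):

import torch
from flow_matching.utils.utils import expand_tensor_like, unsqueeze_to_match

t = torch.rand(8)          # per-sample times, shape (batch_size,)
x = torch.rand(8, 3, 32)   # data batch

t_b = unsqueeze_to_match(t, x)  # shape (8, 1, 1): ready to broadcast against x
t_e = expand_tensor_like(t, x)  # shape (8, 3, 32): fully expanded view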

models/classifier.py
ADDED
@@ -0,0 +1,116 @@
from torch import nn
import torch.nn.functional as F
import torch
import numpy as np
import copy
import pdb

class GaussianFourierProjection(nn.Module):
    """
    Gaussian random features for encoding time steps.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

class Dense(nn.Module):
    """
    A fully connected layer that reshapes outputs to feature maps.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)[...]  # [...] is a no-op index; the tensor is returned unchanged

class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sigmoid(x) * x

class CNNClassifier(nn.Module):
    def __init__(self, args, alphabet_size, num_cls, classifier=False):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.args = args
        self.classifier = classifier
        self.num_cls = num_cls

        if self.args.clean_data:
            self.linear = nn.Embedding(self.alphabet_size, embedding_dim=args.hidden_dim)
        else:
            expanded_simplex_input = args.cls_expanded_simplex or (not classifier and (args.mode == 'dirichlet' or args.mode == 'riemannian'))
            inp_size = self.alphabet_size * (2 if expanded_simplex_input else 1)
            if (args.mode == 'ardm' or args.mode == 'lrar') and not classifier:
                inp_size += 1  # plus one for the mask token of these models
            self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
        self.time_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=args.hidden_dim),
                                           nn.Linear(args.hidden_dim, args.hidden_dim))

        self.num_layers = 5 * args.num_cnn_stacks
        self.convs = [nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)]
        self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
        self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
        self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
        self.final_conv = nn.Sequential(nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
                                        nn.ReLU(),
                                        nn.Conv1d(args.hidden_dim, args.hidden_dim if classifier else self.alphabet_size, kernel_size=1))
        self.dropout = nn.Dropout(args.dropout)
        if classifier:
            self.cls_head = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(args.hidden_dim, self.num_cls))

        if self.args.cls_free_guidance and not self.classifier:
            self.cls_embedder = nn.Embedding(num_embeddings=self.num_cls + 1, embedding_dim=args.hidden_dim)
            self.cls_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])

    def forward(self, seq, t, cls=None, return_embedding=False):
        # pdb.set_trace()
        if self.args.clean_data:
            feat = self.linear(seq)
            feat = feat.permute(0, 2, 1)
        else:
            time_emb = F.relu(self.time_embedder(t))
            feat = seq.permute(0, 2, 1)
            feat = F.relu(self.linear(feat))

        if self.args.cls_free_guidance and not self.classifier and cls is not None:
            # pdb.set_trace()
            cls_emb = self.cls_embedder(cls)

        for i in range(self.num_layers):
            h = self.dropout(feat.clone())
            if not self.args.clean_data:
                h = h + self.time_layers[i](time_emb)[:, :, None]
            if self.args.cls_free_guidance and not self.classifier and cls is not None:
                h = h + self.cls_layers[i](cls_emb)[:, :, None]
            h = self.norms[i]((h).permute(0, 2, 1))
            h = F.relu(self.convs[i](h.permute(0, 2, 1)))
            if h.shape == feat.shape:
                feat = h + feat
            else:
                feat = h
        feat = self.final_conv(feat)
        feat = feat.permute(0, 2, 1)
        if self.classifier:
            feat = feat.mean(dim=1)
            if return_embedding:
                embedding = self.cls_head[:1](feat)
                return self.cls_head[1:](embedding), embedding
            else:
                return self.cls_head(feat)
        return feat
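
A minimal smoke test for CNNClassifier (not part of the upload). The args namespace is hypothetical; its field names simply mirror the attributes read in __init__ and forward:

import torch
from types import SimpleNamespace
from models.classifier import CNNClassifier

args = SimpleNamespace(clean_data=False, cls_expanded_simplex=False, mode='dirichlet',
                       hidden_dim=128, num_cnn_stacks=1, dropout=0.1, cls_free_guidance=False)
model = CNNClassifier(args, alphabet_size=4, num_cls=2, classifier=True)

seq = torch.rand(2, 100, 4)  # simplex-valued input, shape (B, L, alphabet_size)
t = torch.rand(2)            # per-sample times
logits = model(seq, t)       # shape (2, 2): one score per class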

models/enhancer_models.py
ADDED
@@ -0,0 +1,215 @@
from torch import nn
import torch
import numpy as np
import torch.nn.functional as F
import copy
import pdb

class GaussianFourierProjection(nn.Module):
    """
    Gaussian random features for encoding time steps.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

class Dense(nn.Module):
    """
    A fully connected layer that reshapes outputs to feature maps.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)[...]  # [...] is a no-op index; the tensor is returned unchanged

class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sigmoid(x) * x

class CNNModel(nn.Module):
    """A time-dependent score-based model built upon a U-Net-style CNN architecture."""

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            embed_dim (int): Dimensionality of the token and time embeddings.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)

        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
            # nn.Conv1d(n, n, kernel_size=9, padding=4),
            # nn.Conv1d(n, n, kernel_size=9, padding=4),
            # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
            # nn.Conv1d(n, n, kernel_size=9, padding=4),
            # nn.Conv1d(n, n, kernel_size=9, padding=4),
            # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
            # nn.Conv1d(n, n, kernel_size=9, padding=4),
            # nn.Conv1d(n, n, kernel_size=9, padding=4),
            # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256)
        ])

        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )


    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing DNA token indices.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            out: Tensor of shape (B, L, 4) with output logits for each DNA base.
        """
        x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)

        time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)

        out = x.permute(0, 2, 1)            # (B, L, embed_dim) -> (B, embed_dim, L)
        out = self.swish(self.linear(out))  # (B, n, L)

        # Process through convolutional blocks, adding time conditioning via dense layers.
        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            # Residual connection if shapes match.
            if h.shape == out.shape:
                out = h + out
            else:
                out = h

        out = self.final(out)        # (B, 4, L)
        out = out.permute(0, 2, 1)   # (B, L, 4)

        # Normalization
        out = out - out.mean(dim=-1, keepdim=True)
        return out


class MLPModel(nn.Module):
    def __init__(
            self, input_dim: int = 128, time_dim: int = 1, hidden_dim=128, length=500):
        super().__init__()
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.time_embedding = nn.Linear(1, time_dim)
        self.token_embedding = torch.nn.Embedding(self.input_dim, hidden_dim)

        self.swish = Swish()

        self.main = nn.Sequential(
            self.swish,
            nn.Linear(hidden_dim * length + time_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, self.input_dim * length),
        )

    def forward(self, x, t):
        '''
        x shape (B, L)
        t shape (B,)
        '''
        t = self.time_embedding(t.unsqueeze(-1))
        x = self.token_embedding(x)

        B, N, d = x.shape
        x = x.reshape(B, N * d)

        h = torch.cat([x, t], dim=1)
        h = self.main(h)

        h = h.reshape(B, N, self.input_dim)

        return h

class DirichletCNNModel(nn.Module):
    def __init__(self, args, alphabet_size):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.args = args
        expanded_simplex_input = args.cls_expanded_simplex and (args.mode == 'dirichlet' or args.mode == 'riemannian')
        inp_size = self.alphabet_size * (2 if expanded_simplex_input else 1)
        self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
        self.time_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=args.hidden_dim),
                                           nn.Linear(args.hidden_dim, args.hidden_dim))

        self.num_layers = 5 * args.num_cnn_stacks
        self.convs = [nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)]
        self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
        self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
        self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
        self.final_conv = nn.Sequential(nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
                                        nn.ReLU(),
                                        nn.Conv1d(args.hidden_dim, self.alphabet_size, kernel_size=1))
        self.dropout = nn.Dropout(args.dropout)

    def forward(self, seq, t):
        time_emb = F.relu(self.time_embedder(t))
        feat = seq.permute(0, 2, 1)
        feat = F.relu(self.linear(feat))

        for i in range(self.num_layers):
            h = self.dropout(feat.clone())
            if not self.args.clean_data:
                h = h + self.time_layers[i](time_emb)[:, :, None]
            h = self.norms[i]((h).permute(0, 2, 1))
            h = F.relu(self.convs[i](h.permute(0, 2, 1)))
            if h.shape == feat.shape:
                feat = h + feat
            else:
                feat = h
        feat = self.final_conv(feat)
        feat = feat.permute(0, 2, 1)
        return feat
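
A quick shape check for CNNModel (illustrative values; not part of the upload):

import torch
from models.enhancer_models import CNNModel

model = CNNModel(alphabet_size=4, embed_dim=256, hidden_dim=256)
x = torch.randint(0, 4, (2, 500))  # tokenized DNA, shape (B, L)
t = torch.rand(2)
out = model(x, t)                  # shape (2, 500, 4); logits are mean-centered over the base axis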

models/peptide_classifiers.py
ADDED
@@ -0,0 +1,751 @@
import pdb
import torch
import torch.nn.functional as F
import torch.nn as nn
import pytorch_lightning as pl
import time
from transformers import AutoModel, AutoConfig, AutoTokenizer, EsmModel  # EsmModel is referenced below
import xgboost as xgb
import esm

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from flow_matching.loss import MixturePathGeneralizedKL

from models.peptide_models import CNNModel
from modules.bindevaluator_modules import *

def parse_motifs(motif: str) -> torch.Tensor:
    parts = motif.split(',')
    result = []

    for part in parts:
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))

    result = [pos - 1 for pos in result]  # convert 1-based motif positions to 0-based indices
    print(f'Target Motifs: {result}')
    return torch.tensor(result)
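
# Illustrative check (not part of the upload): motif positions in the input
# string are 1-based and converted to 0-based indices by the helper above.
#   parse_motifs("3, 5-7")  ->  prints "Target Motifs: [2, 4, 5, 6]"
#                               and returns tensor([2, 4, 5, 6])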
| 35 |
+
|
| 36 |
+
class BindEvaluator(pl.LightningModule):
|
| 37 |
+
def __init__(self, n_layers, d_model, d_hidden, n_head,
|
| 38 |
+
d_k, d_v, d_inner, dropout=0.2,
|
| 39 |
+
learning_rate=0.00001, max_epochs=15, kl_weight=1):
|
| 40 |
+
super(BindEvaluator, self).__init__()
|
| 41 |
+
|
| 42 |
+
self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 43 |
+
self.esm_model.eval()
|
| 44 |
+
# freeze all the esm_model parameters
|
| 45 |
+
for param in self.esm_model.parameters():
|
| 46 |
+
param.requires_grad = False
|
| 47 |
+
|
| 48 |
+
self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
|
| 49 |
+
n_head, d_k, d_v, d_inner, dropout=dropout)
|
| 50 |
+
|
| 51 |
+
self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
|
| 52 |
+
d_k, d_v, dropout=dropout)
|
| 53 |
+
|
| 54 |
+
self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
|
| 55 |
+
|
| 56 |
+
self.output_projection_prot = nn.Linear(d_model, 1)
|
| 57 |
+
|
| 58 |
+
self.learning_rate = learning_rate
|
| 59 |
+
self.max_epochs = max_epochs
|
| 60 |
+
self.kl_weight = kl_weight
|
| 61 |
+
|
| 62 |
+
self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial threshold
|
| 63 |
+
self.historical_memory = 0.9
|
| 64 |
+
self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925]) # binding_site weights, non-bidning site weights
|
| 65 |
+
|
| 66 |
+
def forward(self, binder_tokens, target_tokens):
|
| 67 |
+
peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
|
| 68 |
+
protein_sequence = self.esm_model(**target_tokens).last_hidden_state
|
| 69 |
+
|
| 70 |
+
prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
|
| 71 |
+
seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
|
| 72 |
+
protein_sequence)
|
| 73 |
+
|
| 74 |
+
prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
|
| 75 |
+
|
| 76 |
+
prot_enc = self.final_ffn(prot_enc)
|
| 77 |
+
|
| 78 |
+
prot_enc = self.output_projection_prot(prot_enc)
|
| 79 |
+
|
| 80 |
+
return prot_enc
|
| 81 |
+
|
| 82 |
+
    def get_probs(self, x_t, target_sequence):
        '''
        Inputs:
        - x_t: Shape (bsz, seq_len)
        - target_sequence: Shape (1, tgt_len)
        '''
        target_sequence = target_sequence.repeat(x_t.shape[0], 1)
        binder_attention_mask = torch.ones_like(x_t)
        target_attention_mask = torch.ones_like(target_sequence)

        # Zero out the BOS/EOS special-token positions in both masks.
        binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
        target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0

        binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)}
        target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}

        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        # Push the special-token positions toward probability ~0 (a large negative
        # constant is used instead of -inf to keep sigmoid and gradients finite).
        logits[:, 0] = logits[:, -1] = -100
        probs = torch.sigmoid(logits)

        return probs  # shape (bsz, tgt_len)

    def motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)
        return motif_score

    def non_motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
        mask = non_motif_probs >= 0.5
        count = mask.sum(dim=-1)

        # Mean probability over confidently predicted off-motif positions; 0 when none exceed 0.5.
        non_motif_score = torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count, torch.zeros_like(count))

        return non_motif_score

    def scoring(self, x_t, target_sequence, motifs, penalty=False):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)

        if penalty:
            non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
            mask = non_motif_probs >= 0.5
            count = mask.sum(dim=-1)
            # Penalty term: fraction of target positions predicted as off-motif binding sites.
            non_motif_score = count / target_sequence.shape[1]
            return motif_score, 1 - non_motif_score
        else:
            return motif_score

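# A hedged usage sketch of BindEvaluator.scoring (tensor names are illustrative,
# not values from this repo):
#   motifs = parse_motifs("45-52")          # 0-indexed hotspot positions on the target
#   score = evaluator.scoring(x_t, target_ids, motifs)                     # (bsz,)
#   score, keep = evaluator.scoring(x_t, target_ids, motifs, penalty=True)
#   # `keep` = 1 - fraction of off-motif target positions predicted as binding sites
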
class MotifModel(nn.Module):
    """Wraps a frozen BindEvaluator into a single-argument score function over binder tokens."""
    def __init__(self, bindevaluator, target_sequence, motifs, penalty=False):
        super(MotifModel, self).__init__()
        self.bindevaluator = bindevaluator
        self.target_sequence = target_sequence
        self.motifs = motifs
        self.penalty = penalty

    def forward(self, x):
        return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)

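# Hedged sketch: freezing the evaluator and exposing it as a guidance score
# function of the binder tokens alone (names are assumptions):
#   evaluator = load_bindevaluator(ckpt_path, device)
#   score_fn = MotifModel(evaluator, target_ids, parse_motifs("45-52"), penalty=True)
#   motif_score, keep_score = score_fn(x_t)   # both shaped (bsz,)
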
class UnpooledBindingPredictor(nn.Module):
    def __init__(self,
                 esm_model_name="facebook/esm2_t33_650M_UR50D",
                 hidden_dim=512,
                 kernel_sizes=[3, 5, 7],
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1,
                 freeze_esm=True):
        super().__init__()

        # Binding thresholds on -log10 affinity (pKd/pKi/pIC50).
        self.tight_threshold = 7.5  # Kd/Ki/IC50 <= ~30 nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1 uM

        # Load ESM model for computing embeddings on the fly
        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.config = AutoConfig.from_pretrained(esm_model_name)

        # Freeze ESM parameters if requested
        if freeze_esm:
            for param in self.esm_model.parameters():
                param.requires_grad = False

        # ESM hidden size
        esm_dim = self.config.hidden_size

        # Output channels per CNN kernel size
        output_channels_per_kernel = 64

        # CNN layers for handling variable-length sequences
        self.protein_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        self.binder_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        # Total features after convolution and (max + average) pooling
        total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2

        # Project both sequences to the same dimension after CNN processing
        self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
        self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)

        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.binder_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, weak binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values (pKd/pKi/pIC50) to class indices.
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

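    # Worked example of the thresholds: pKd = 7.5 corresponds to Kd = 10^-7.5 M (~32 nM)
    # and pKd = 6.0 to 1 uM, so for example:
    #   self.get_binding_class(torch.tensor([8.1, 6.7, 5.2]))  ->  tensor([0, 1, 2])
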
    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly."""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Unpooled last hidden states: (batch_size, seq_length, hidden_size)
        return esm_outputs.last_hidden_state

    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
        """Run a sequence through the CNN layers, then masked max/average pooling."""
        # Transpose for Conv1d: (batch_size, hidden_size, seq_length)
        x = unpooled_emb.transpose(1, 2)

        # Apply CNN layers and collect outputs
        conv_outputs = []
        for conv in conv_layers:
            conv_out = F.relu(conv(x))
            conv_outputs.append(conv_out)

        # Concatenate along the channel dimension
        conv_output = torch.cat(conv_outputs, dim=1)

        # Global pooling (both max and average). If an attention mask is provided,
        # use it so that padding positions do not contribute.
        if attention_mask is not None:
            # Expand mask (1 = valid, 0 = padding) to match conv_output channels;
            # cast to float so the division below stays in floating point.
            expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1).to(conv_output.dtype)

            # Max pooling: set padding positions to -inf so they can never win
            masked_output = conv_output.clone()
            masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))
            max_pooled = torch.max(masked_output, dim=2)[0]

            # Average pooling: sum over valid positions divided by their count
            sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
            valid_positions = torch.sum(expanded_mask, dim=2)
            valid_positions = torch.clamp(valid_positions, min=1.0)  # avoid division by zero
            avg_pooled = sum_pooled / valid_positions
        else:
            # No mask: standard pooling over the full sequence
            max_pooled = torch.max(conv_output, dim=2)[0]
            avg_pooled = torch.mean(conv_output, dim=2)

        # Concatenate the pooled features
        pooled = torch.cat([max_pooled, avg_pooled], dim=1)

        return pooled

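    # Shape sketch for process_sequence with the defaults (hand-worked, not from the repo):
    #   unpooled_emb (B, L, 1280) -> transpose -> (B, 1280, L)
    #   three Conv1d branches (k = 3, 5, 7), 64 channels each -> concat -> (B, 192, L)
    #   max-pool + avg-pool over L, concatenated -> (B, 384)
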
    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
        # Compute embeddings on the fly using the ESM model
        protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
        binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)

        # Process protein and binder sequences through the CNN layers
        protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
        binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)

        # Project both to the shared dimension
        protein = self.protein_norm(self.protein_projection(protein_features))
        binder = self.binder_norm(self.binder_projection(binder_features))

        # Reshape for attention: (batch_size, hidden_dim) -> (1, batch_size, hidden_dim)
        protein = protein.unsqueeze(0)
        binder = binder.unsqueeze(0)

        # Cross-attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to binder
            attended_protein = layer['attention'](
                protein, binder, binder
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # Binder attending to protein
            attended_binder = layer['attention'](
                binder, protein, protein
            )[0]
            binder = layer['norm1'](binder + attended_binder)
            binder = layer['norm2'](binder + layer['ffn'](binder))

        # Remove the sequence dimension
        protein_pool = protein.squeeze(0)
        binder_pool = binder.squeeze(0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, binder_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        # The classification head is defined but unused here; only regression is returned.

        return regression_output

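# Hypothetical usage via the standard Hugging Face tokenizer (sequences made up):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   prot = tok(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"], return_tensors="pt")
#   pep = tok(["GLFDIIKKIAESF"], return_tensors="pt")
#   pkd = model(prot.input_ids, pep.input_ids, prot.attention_mask, pep.attention_mask)  # (1, 1)
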
class ImprovedBindingPredictor(nn.Module):
    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=1280,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=5,
                 dropout=0.1):
        super().__init__()

        # Binding thresholds on -log10 affinity (pKd/pKi/pIC50).
        self.tight_threshold = 7.5  # Kd/Ki/IC50 <= ~30 nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1 uM

        # Project both inputs to the same dimension
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, weak binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values (pKd/pKi/pIC50) to class indices.
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def forward(self, protein_emb, binder_emb):
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(binder_emb))

        # (B, L, hidden_dim) -> (L, B, hidden_dim) for nn.MultiheadAttention
        protein = protein.transpose(0, 1)
        smiles = smiles.transpose(0, 1)

        # Cross-attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to the binder ("smiles") stream
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # Binder attending to protein
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Sequence-level representations via mean pooling over length
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)

        return regression_output

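# Note: despite the "smiles" naming (apparently inherited from a small-molecule
# variant), this head consumes per-residue ESM embeddings for both chains. Shape sketch:
#   model(protein_emb, binder_emb)   # (B, Lp, 1280), (B, Lb, 1280) -> (B, 1)
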
class PooledAffinityModel(nn.Module):
    def __init__(self, affinity_predictor, target_sequence):
        super(PooledAffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence
        self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device)
        for param in self.esm_model.parameters():
            param.requires_grad = False

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly."""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Unpooled last hidden states: (batch_size, seq_length, hidden_size)
        return esm_outputs.last_hidden_state

    def forward(self, x):
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)

        protein_emb = self.compute_embeddings(input_ids=target_sequence)
        binder_emb = self.compute_embeddings(input_ids=x)
        return self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb).squeeze(-1)

class AffinityModel(nn.Module):
    def __init__(self, affinity_predictor, target_sequence):
        super(AffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence

    def forward(self, x):
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)
        affinity = self.affinity_predictor(protein_input_ids=target_sequence, binder_input_ids=x).squeeze(-1)
        # Scale the predicted pKd-style affinity down by 10 so the guidance score is roughly in [0, 1].
        return affinity / 10

class HemolysisModel:
    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_hemolysis.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for tokenised peptide sequences."""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # Return the probability of being non-hemolytic (1 - P(hemolytic)).
        return torch.from_numpy(scores - probs).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

class NonfoulingModel:
    def __init__(self, device):
        # NOTE: adjust the checkpoint path to your environment.
        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_nonfouling.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for tokenised peptide sequences."""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

class SolubilityModel:
    def __init__(self, device):
        # NOTE: adjust the checkpoint path to your environment.
        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for tokenised peptide sequences."""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

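# The three XGBoost property scorers share one pattern: mean-pooled ESM-2
# embeddings -> Booster.predict -> probability in [0, 1]. A hedged usage sketch
# (token IDs are assumed to come from the ESM-2 tokenizer):
#   scorer = SolubilityModel(device)
#   p_soluble = scorer(token_ids)   # tensor of shape (bsz,)
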
class SolubilityModelNew:
    def __init__(self, device):
        # ESM-2 token IDs that correspond to the hydrophobic residues
        # A, V, L, I, M, F, W, P in the standard ESM-2 vocabulary.
        self.hydro_ids = torch.tensor([5, 7, 4, 12, 20, 18, 22, 14], device=device)
        self.device = device

    def get_scores(self, x):
        # Per-sequence fraction of hydrophobic tokens; the solubility score is its complement.
        mask = (x.unsqueeze(-1) == self.hydro_ids).any(dim=-1)
        ratios = mask.float().mean(dim=1)
        return 1 - ratios

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

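# Worked example: for a 10-residue peptide with 4 hydrophobic tokens the
# hydrophobic ratio is 0.4, so the heuristic solubility score is 1 - 0.4 = 0.6:
#   x = torch.tensor([[5, 8, 9, 4, 4, 10, 15, 7, 9, 11]], device=device)
#   SolubilityModelNew(device).get_scores(x)   # tensor([0.6000])
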
class PeptideCNN(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
        # Note: kernel_size=5 with padding=1 shortens the sequence by 2; this is
        # harmless here because the output is globally average-pooled below.
        self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
        self.fc = nn.Linear(hidden_dims[1], output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)  # regression/classification head

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        with torch.no_grad():
            x = self.esm_model(input_ids, attention_mask).last_hidden_state
        # x shape: (B, L, input_dim)
        x = x.permute(0, 2, 1)  # reshape to (B, input_dim, L) for Conv1d
        x = nn.functional.relu(self.conv1(x))
        x = self.dropout(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.permute(0, 2, 1)  # back to (B, L', hidden_dims[1])

        # Global average pooling over the sequence dimension
        x = x.mean(dim=1)  # shape: (B, hidden_dims[1])

        features = self.fc(x)  # (B, output_dim)
        if return_features:
            return features
        return self.predictor(features)  # (B, 1)

class HalfLifeModel:
    def __init__(self, device):
        input_dim = 1280
        hidden_dims = [input_dim // 2, input_dim // 4]
        output_dim = input_dim // 8
        dropout_rate = 0.3
        self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
        self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
        self.model.eval()

    def __call__(self, x):
        prediction = self.model(x, return_features=False)
        # Clamp the prediction to [0, 2] and rescale to [0, 1].
        halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
        return halflife / 2


def load_bindevaluator(checkpoint_path, device):
    """Load a frozen BindEvaluator from a Lightning checkpoint."""
    bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
    bindevaluator.eval()
    for param in bindevaluator.parameters():
        param.requires_grad = False

    return bindevaluator


def load_solver(checkpoint_path, vocab_size, device):
    """Load a frozen discrete flow-matching denoiser and wrap it in a discrete Euler solver."""
    embed_dim = 512
    hidden_dim = 256

    probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
    probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
    probability_denoiser.eval()
    for param in probability_denoiser.parameters():
        param.requires_grad = False

    # Instantiate a convex probability path with a quadratic scheduler.
    scheduler = PolynomialConvexScheduler(n=2.0)
    path = MixtureDiscreteProbPath(scheduler=scheduler)

    class WrappedModel(ModelWrapper):
        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
            # Convert the denoiser logits into per-position posterior probabilities.
            return torch.softmax(self.model(x, t), dim=-1)

    wrapped_probability_denoiser = WrappedModel(probability_denoiser)
    solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)

    return solver


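# Hedged sampling sketch; the `sample` signature follows the bundled
# flow_matching discrete solver (an assumption, verify against your version):
#   x_init = torch.randint(0, vocab_size, (16, pep_len), device=device)
#   x_1 = solver.sample(x_init=x_init, step_size=1 / 100,
#                       time_grid=torch.tensor([0.0, 1.0]))   # (16, pep_len) token IDs
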
def load_pooled_affinity_predictor(checkpoint_path, device):
    """Load a trained ImprovedBindingPredictor from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = ImprovedBindingPredictor().to(device)

    # Load the trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # evaluation mode

    return model

def load_affinity_predictor(checkpoint_path, device):
    """Load a trained UnpooledBindingPredictor from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = UnpooledBindingPredictor(
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        hidden_dim=384,
        kernel_sizes=[3, 5, 7],
        n_heads=8,
        n_layers=4,
        dropout=0.14561457009902096,
        freeze_esm=True
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model

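# Putting the loaders together (a hedged sketch; paths and vocab size are placeholders):
#   predictor = load_affinity_predictor("ckpt/affinity.pt", device)
#   affinity_fn = AffinityModel(predictor, target_ids)   # binder tokens -> score ~ [0, 1]
#   solver = load_solver("ckpt/denoiser.pt", vocab_size=33, device=device)
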
models/peptide_models.py
ADDED
@@ -0,0 +1,359 @@
from torch import nn
import torch
import numpy as np
from transformers import AutoModel
import torch.nn.functional as F
import esm
import copy
import pdb

class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

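# Shape sketch: for t of shape (B,) and embed_dim = 256, x_proj is (B, 128) and the
# concatenated sin/cos features are (B, 256):
#   GaussianFourierProjection(256)(torch.rand(4)).shape  ->  torch.Size([4, 256])
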
class Dense(nn.Module):
    """A fully connected layer whose output is broadcast over feature maps by callers."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)

class Swish(nn.Module):
    """Swish/SiLU activation: x * sigmoid(x) (equivalent to nn.SiLU)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sigmoid(x) * x

class CNNESMModel(nn.Module):
    """A time-dependent denoiser: frozen ESM-2 embeddings followed by a residual dilated CNN."""

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            embed_dim (int): Dimensionality of the token and time embeddings.
                NOTE: must equal the ESM hidden size (1280 for esm2_t33_650M)
                for self.linear below to accept the ESM features.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        # Frozen ESM-2 encoder replaces a learned token embedding.
        self.esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm.eval()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)

        # Five conv blocks with increasing dilation to grow the receptive field.
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
        ])

        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing token indices.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            out: Tensor of shape (B, L, alphabet_size) with per-position output logits.
        """
        with torch.no_grad():
            x = self.esm(input_ids=x).last_hidden_state
        time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)

        out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
        out = self.swish(self.linear(out))  # (B, n, L)

        # Process through convolutional blocks, adding time conditioning via dense layers.
        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            # Residual connection if shapes match.
            if h.shape == out.shape:
                out = h + out
            else:
                out = h

        out = self.final(out)  # (B, alphabet_size, L)
        out = out.permute(0, 2, 1)  # (B, L, alphabet_size)

        # Zero-center the logits over the alphabet dimension.
        out = out - out.mean(dim=-1, keepdim=True)
        return out


class MLPModel(nn.Module):
    def __init__(
            self, input_dim: int = 128, time_dim: int = 1, hidden_dim=128, length=500):
        super().__init__()
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        self.time_embedding = nn.Linear(1, time_dim)
        self.token_embedding = torch.nn.Embedding(self.input_dim, hidden_dim)

        self.swish = Swish()

        self.main = nn.Sequential(
            self.swish,
            nn.Linear(hidden_dim * length + time_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, hidden_dim),
            self.swish,
            nn.Linear(hidden_dim, self.input_dim * length),
        )

    def forward(self, x, t):
        '''
        x: shape (B, L)
        t: shape (B,)
        '''
        t = self.time_embedding(t.unsqueeze(-1))
        x = self.token_embedding(x)

        B, N, d = x.shape
        x = x.reshape(B, N * d)

        h = torch.cat([x, t], dim=1)
        h = self.main(h)

        # Per-position logits over the vocabulary: (B, L, input_dim)
        h = h.reshape(B, N, self.input_dim)

        return h

class CNNModel(nn.Module):
    """A time-dependent denoiser: learned token embedding followed by a residual dilated CNN."""

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            embed_dim (int): Dimensionality of the token and time embeddings.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)

        # Five conv blocks with increasing dilation to grow the receptive field.
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, padding=4),
            nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
            nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
            nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
        ])

        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing token indices.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            out: Tensor of shape (B, L, alphabet_size) with per-position output logits.
        """
        x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)
        time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)

        out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
        out = self.swish(self.linear(out))  # (B, n, L)

        # Process through convolutional blocks, adding time conditioning via dense layers.
        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            # Residual connection if shapes match.
            if h.shape == out.shape:
                out = h + out
            else:
                out = h

        out = self.final(out)  # (B, alphabet_size, L)
        out = out.permute(0, 2, 1)  # (B, L, alphabet_size)

        # Zero-center the logits over the alphabet dimension.
        out = out - out.mean(dim=-1, keepdim=True)
        return out

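# Quick shape check (values are arbitrary):
#   model = CNNModel(alphabet_size=33, embed_dim=512, hidden_dim=256)
#   x = torch.randint(0, 33, (8, 12))
#   t = torch.rand(8)
#   model(x, t).shape   ->  torch.Size([8, 12, 33])
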
class CNNModel_Large(nn.Module):
    """A larger variant of CNNModel with 20 convolutional blocks (4 dilation cycles)."""

    def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
        """
        Args:
            embed_dim (int): Dimensionality of the token and time embeddings.
        """
        super().__init__()
        self.alphabet_size = alphabet_size

        self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)

        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.swish = Swish()

        n = hidden_dim

        self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)

        # Four repetitions of the five-block (dilation, padding) cycle used in CNNModel.
        dilation_cycle = [(1, 4), (1, 4), (4, 16), (16, 64), (64, 256)]
        self.blocks = nn.ModuleList([
            nn.Conv1d(n, n, kernel_size=9, dilation=d, padding=p)
            for _ in range(4) for d, p in dilation_cycle
        ])

        self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(20)])
        self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(20)])

        self.final = nn.Sequential(
            nn.Conv1d(n, n, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(n, self.alphabet_size, kernel_size=1)
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (B, L) containing token indices.
            t: Tensor of shape (B,) containing the time steps.
        Returns:
            out: Tensor of shape (B, L, alphabet_size) with per-position output logits.
        """
        x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)
        time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)

        out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
        out = self.swish(self.linear(out))  # (B, n, L)

        # Process through convolutional blocks, adding time conditioning via dense layers.
        for block, dense, norm in zip(self.blocks, self.denses, self.norms):
            # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
            h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
            # Residual connection if shapes match.
            if h.shape == out.shape:
                out = h + out
            else:
                out = h

        out = self.final(out)  # (B, alphabet_size, L)
        out = out.permute(0, 2, 1)  # (B, L, alphabet_size)

        # Zero-center the logits over the alphabet dimension.
        out = out - out.mean(dim=-1, keepdim=True)
        return out
modules/bindevaluator_modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .models import *
from .score_domain import *
from .dataloaders import *
modules/bindevaluator_modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (234 Bytes)
modules/bindevaluator_modules/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (228 Bytes)
modules/bindevaluator_modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (234 Bytes)
modules/bindevaluator_modules/__pycache__/dataloaders.cpython-310.pyc
ADDED
Binary file (7.93 kB)
modules/bindevaluator_modules/__pycache__/dataloaders.cpython-38.pyc
ADDED
Binary file (8.44 kB)
modules/bindevaluator_modules/__pycache__/dataloaders.cpython-39.pyc
ADDED
Binary file (8.59 kB)
modules/bindevaluator_modules/__pycache__/layers.cpython-310.pyc
ADDED
Binary file (3.58 kB)
modules/bindevaluator_modules/__pycache__/layers.cpython-38.pyc
ADDED
Binary file (3.68 kB)