AlienChen committed
Commit 3527383 · verified · 1 parent: ccef67d

Upload 72 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. bindevaluator.py +182 -0
  2. classifier_code/__init__.py +0 -0
  3. classifier_code/binding_affinity_unpooled.py +321 -0
  4. classifier_code/binding_affinity_unpooled_2.py +356 -0
  5. classifier_code/half_life.py +65 -0
  6. classifier_code/hemolysis_wt.py +101 -0
  7. classifier_code/nonfouling_wt.py +98 -0
  8. classifier_code/solubility_wt.py +98 -0
  9. flow_matching/__init__.py +7 -0
  10. flow_matching/loss/__init__.py +11 -0
  11. flow_matching/loss/generalized_loss.py +83 -0
  12. flow_matching/path/__init__.py +22 -0
  13. flow_matching/path/affine.py +260 -0
  14. flow_matching/path/geodesic.py +100 -0
  15. flow_matching/path/mixture.py +117 -0
  16. flow_matching/path/path.py +61 -0
  17. flow_matching/path/path_sample.py +53 -0
  18. flow_matching/path/scheduler/__init__.py +29 -0
  19. flow_matching/path/scheduler/schedule_transform.py +148 -0
  20. flow_matching/path/scheduler/scheduler.py +199 -0
  21. flow_matching/solver/__init__.py +18 -0
  22. flow_matching/solver/discrete_solver.py +428 -0
  23. flow_matching/solver/ode_solver.py +197 -0
  24. flow_matching/solver/riemannian_ode_solver.py +261 -0
  25. flow_matching/solver/solver.py +17 -0
  26. flow_matching/solver/utils.py +19 -0
  27. flow_matching/utils/__init__.py +17 -0
  28. flow_matching/utils/categorical_sampler.py +23 -0
  29. flow_matching/utils/manifolds/__init__.py +18 -0
  30. flow_matching/utils/manifolds/manifold.py +93 -0
  31. flow_matching/utils/manifolds/sphere.py +45 -0
  32. flow_matching/utils/manifolds/torus.py +28 -0
  33. flow_matching/utils/manifolds/utils.py +45 -0
  34. flow_matching/utils/model_wrapper.py +43 -0
  35. flow_matching/utils/multi_guidance.py +216 -0
  36. flow_matching/utils/multi_guidance_cnp.py +217 -0
  37. flow_matching/utils/utils.py +90 -0
  38. models/classifier.py +116 -0
  39. models/enhancer_models.py +215 -0
  40. models/peptide_classifiers.py +751 -0
  41. models/peptide_models.py +359 -0
  42. modules/bindevaluator_modules/__init__.py +3 -0
  43. modules/bindevaluator_modules/__pycache__/__init__.cpython-310.pyc +0 -0
  44. modules/bindevaluator_modules/__pycache__/__init__.cpython-38.pyc +0 -0
  45. modules/bindevaluator_modules/__pycache__/__init__.cpython-39.pyc +0 -0
  46. modules/bindevaluator_modules/__pycache__/dataloaders.cpython-310.pyc +0 -0
  47. modules/bindevaluator_modules/__pycache__/dataloaders.cpython-38.pyc +0 -0
  48. modules/bindevaluator_modules/__pycache__/dataloaders.cpython-39.pyc +0 -0
  49. modules/bindevaluator_modules/__pycache__/layers.cpython-310.pyc +0 -0
  50. modules/bindevaluator_modules/__pycache__/layers.cpython-38.pyc +0 -0
bindevaluator.py ADDED
@@ -0,0 +1,182 @@
+ import torch
+ import numpy as np
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ from datasets import load_from_disk
+ from transformers import AutoTokenizer
+ from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, accuracy_score
+ from argparse import ArgumentParser
+ import os
+ import torch.distributed as dist
+ import pandas as pd
+
+ from modules.bindevaluator_modules import *  # provides EsmModel, nn, RepeatedModule3, MultiHeadAttentionSequence, FFN
+
+
+ def parse_motifs(motif: str) -> torch.Tensor:
+     """Parse a 1-indexed motif string such as '3,7-10,15' into a 0-indexed position tensor."""
+     parts = motif.split(',')
+     result = []
+
+     for part in parts:
+         part = part.strip()
+         if '-' in part:
+             start, end = map(int, part.split('-'))
+             result.extend(range(start, end + 1))
+         else:
+             result.append(int(part))
+
+     result = [pos - 1 for pos in result]  # convert to 0-indexed positions
+     print(f'Target Motifs: {result}')
+     return torch.tensor(result)
+
+
+ class PeptideModel(pl.LightningModule):
+     def __init__(self, n_layers, d_model, d_hidden, n_head,
+                  d_k, d_v, d_inner, dropout=0.2,
+                  learning_rate=0.00001, max_epochs=15, kl_weight=1):
+         super(PeptideModel, self).__init__()
+
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         # Freeze all ESM parameters; only the downstream modules are trained
+         for param in self.esm_model.parameters():
+             param.requires_grad = False
+
+         self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
+                                                n_head, d_k, d_v, d_inner, dropout=dropout)
+
+         self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
+                                                                 d_k, d_v, dropout=dropout)
+
+         self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
+
+         self.output_projection_prot = nn.Linear(d_model, 1)
+
+         self.learning_rate = learning_rate
+         self.max_epochs = max_epochs
+         self.kl_weight = kl_weight
+
+         self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # initial threshold
+         self.historical_memory = 0.9
+         self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site / non-binding-site weights
+
+     def forward(self, binder_tokens, target_tokens):
+         peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
+         protein_sequence = self.esm_model(**target_tokens).last_hidden_state
+
+         prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
+             seq_prot_attention_list, prot_seq_attention_list = self.repeated_module(peptide_sequence,
+                                                                                     protein_sequence)
+
+         prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
+
+         prot_enc = self.final_ffn(prot_enc)
+
+         prot_enc = self.output_projection_prot(prot_enc)
+
+         return prot_enc
+
+
+ def calculate_score(target_sequence, binder_sequence, model, args):
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+     anchor_tokens = tokenizer(target_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000)
+     positive_tokens = tokenizer(binder_sequence, return_tensors='pt', padding=True, truncation=True, max_length=40000)
+
+     # Zero out the attention mask on the BOS/EOS special tokens
+     anchor_tokens['attention_mask'][0][0] = 0
+     anchor_tokens['attention_mask'][0][-1] = 0
+     positive_tokens['attention_mask'][0][0] = 0
+     positive_tokens['attention_mask'][0][-1] = 0
+
+     target_tokens = {'input_ids': anchor_tokens["input_ids"].to(device),
+                      'attention_mask': anchor_tokens["attention_mask"].to(device)}
+     binder_tokens = {'input_ids': positive_tokens['input_ids'].to(device),
+                      'attention_mask': positive_tokens['attention_mask'].to(device)}
+
+     model.eval()
+
+     # Drop the BOS/EOS positions from the per-residue logits, then map to probabilities
+     prediction = model(binder_tokens, target_tokens).squeeze(-1)[0][1:-1]
+     prediction = torch.sigmoid(prediction)
+
+     return prediction, model.classification_threshold
+
+
+ def compute_metrics(true_residues, predicted_residues, length):
+     # Initialize the true and predicted label vectors with 0
+     true_list = [0] * length
+     predicted_list = [0] * length
+
+     # Set positions to 1 based on the provided residue indices
+     for index in true_residues:
+         true_list[index] = 1
+     for index in predicted_residues:
+         predicted_list[index] = 1
+
+     # Compute the metrics
+     accuracy = accuracy_score(true_list, predicted_list)
+     f1 = f1_score(true_list, predicted_list)
+     mcc = matthews_corrcoef(true_list, predicted_list)
+
+     return accuracy, f1, mcc
+
+
+ def main():
+     parser = ArgumentParser()
+     parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt',
+                         help="File containing initial params", type=str)
+     parser.add_argument("-batch_size", type=int, default=32, help="Batch size")
+     parser.add_argument("-lr", type=float, default=1e-3)
+     parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
+     parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
+     parser.add_argument("-d_hidden", type=int, default=128, help="Dimension of CNN block")
+     parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
+     parser.add_argument("-d_inner", type=int, default=64)
+     parser.add_argument("-target", type=str)
+     parser.add_argument("-binder", type=str)
+     parser.add_argument("-gt", type=str, default=None)
+     parser.add_argument("-motifs", type=str, default=None)
+     args = parser.parse_args()
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     model = PeptideModel.load_from_checkpoint(args.sm,
+                                               n_layers=args.n_layers,
+                                               d_model=args.d_model,
+                                               d_hidden=args.d_hidden,
+                                               n_head=args.n_head,
+                                               d_k=64,
+                                               d_v=128,
+                                               d_inner=64).to(device)
+
+     prediction, _ = calculate_score(args.target, args.binder, model, args)
+
+     binding_site = []
+     for i in range(len(prediction)):
+         if prediction[i] >= 0.5:
+             binding_site.append(i)
+
+     print("Prediction: ", binding_site)
+     prediction = prediction.detach().cpu()  # keep as a tensor so motif indexing below works
+     np.set_printoptions(precision=2, suppress=True)
+     print(prediction.tolist())
+
+     if args.motifs is not None:
+         motifs = parse_motifs(args.motifs)
+         print(f"Motif Score: {torch.sum(prediction[motifs]) / len(motifs)}")
+
+     if args.gt is not None:
+         L = len(args.target)
+         gt = parse_motifs(args.gt)
+         print("Ground Truth: ", gt)
+
+         acc, f1, mcc = compute_metrics(gt, binding_site, L)
+         print(f"Accuracy={acc}\tF1={f1}\tMCC={mcc}")
+
+
+ if __name__ == "__main__":
+     main()
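
For reference, a minimal sketch of driving the evaluator as a library rather than through main(). The checkpoint path and both sequences below are placeholders, and the constructor arguments mirror the CLI defaults above:

import torch
from bindevaluator import PeptideModel, calculate_score

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = PeptideModel.load_from_checkpoint(
    "path/to/model.ckpt",  # placeholder checkpoint path
    n_layers=6, d_model=64, d_hidden=128,
    n_head=6, d_k=64, d_v=128, d_inner=64,
).to(device)

probs, threshold = calculate_score("GSHMIEPNVISVRLFKRKV",  # placeholder target sequence
                                   "VVKVDSV",              # placeholder binder sequence
                                   model, args=None)       # args is unused by calculate_score
print((probs >= 0.5).nonzero(as_tuple=True)[0].tolist())   # predicted binding-site indices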
classifier_code/__init__.py ADDED
File without changes
classifier_code/binding_affinity_unpooled.py ADDED
@@ -0,0 +1,321 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import accuracy_score, f1_score
+ from scipy.stats import spearmanr
+ from collections import defaultdict
+ import pandas as pd
+ import logging
+ import os
+ import torch.optim as optim
+ from datetime import datetime
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
+
+
+ class UnpooledBindingPredictor(nn.Module):
+     def __init__(self,
+                  esm_model_name="facebook/esm2_t33_650M_UR50D",
+                  hidden_dim=512,
+                  kernel_sizes=[3, 5, 7],
+                  n_heads=8,
+                  n_layers=3,
+                  dropout=0.1,
+                  freeze_esm=True):
+         super().__init__()
+
+         # Define binding thresholds on the pKd/pKi scale
+         self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30 nM
+         self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1 μM
+
+         # Load ESM model for computing embeddings on the fly
+         self.esm_model = AutoModel.from_pretrained(esm_model_name)
+         self.config = AutoConfig.from_pretrained(esm_model_name)
+
+         # Freeze ESM parameters if requested
+         if freeze_esm:
+             for param in self.esm_model.parameters():
+                 param.requires_grad = False
+
+         # Get ESM hidden size
+         esm_dim = self.config.hidden_size
+
+         # Output channels for CNN layers
+         output_channels_per_kernel = 64
+
+         # CNN layers for handling variable-length sequences
+         self.protein_conv_layers = nn.ModuleList([
+             nn.Conv1d(
+                 in_channels=esm_dim,
+                 out_channels=output_channels_per_kernel,
+                 kernel_size=k,
+                 padding='same'
+             ) for k in kernel_sizes
+         ])
+
+         self.binder_conv_layers = nn.ModuleList([
+             nn.Conv1d(
+                 in_channels=esm_dim,
+                 out_channels=output_channels_per_kernel,
+                 kernel_size=k,
+                 padding='same'
+             ) for k in kernel_sizes
+         ])
+
+         # Total features after convolution and (max + avg) pooling
+         total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2
+
+         # Project to the same dimension after CNN processing
+         self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
+         self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)
+
+         self.protein_norm = nn.LayerNorm(hidden_dim)
+         self.binder_norm = nn.LayerNorm(hidden_dim)
+
+         # Cross-attention blocks with layer norm
+         self.cross_attention_layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                 'norm1': nn.LayerNorm(hidden_dim),
+                 'ffn': nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim * 4),
+                     nn.ReLU(),
+                     nn.Dropout(dropout),
+                     nn.Linear(hidden_dim * 4, hidden_dim)
+                 ),
+                 'norm2': nn.LayerNorm(hidden_dim)
+             }) for _ in range(n_layers)
+         ])
+
+         # Prediction heads
+         self.shared_head = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+
+         # Regression head
+         self.regression_head = nn.Linear(hidden_dim, 1)
+
+         # Classification head (3 classes: tight, medium, weak binding)
+         self.classification_head = nn.Linear(hidden_dim, 3)
+
+     def get_binding_class(self, affinity):
+         """Convert affinity values to class indices.
+         0: tight binding (>= 7.5)
+         1: medium binding (6.0-7.5)
+         2: weak binding (< 6.0)
+         """
+         if isinstance(affinity, torch.Tensor):
+             tight_mask = affinity >= self.tight_threshold
+             weak_mask = affinity < self.weak_threshold
+             medium_mask = ~(tight_mask | weak_mask)
+
+             classes = torch.zeros_like(affinity, dtype=torch.long)
+             classes[medium_mask] = 1
+             classes[weak_mask] = 2
+             return classes
+         else:
+             if affinity >= self.tight_threshold:
+                 return 0  # tight binding
+             elif affinity < self.weak_threshold:
+                 return 2  # weak binding
+             else:
+                 return 1  # medium binding
+
+     def compute_embeddings(self, input_ids, attention_mask=None):
+         """Compute ESM embeddings on the fly."""
+         esm_outputs = self.esm_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             return_dict=True
+         )
+
+         # Unpooled last hidden states: (batch_size, seq_length, hidden_size)
+         return esm_outputs.last_hidden_state
+
+     def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
+         """Process a sequence through CNN layers and pooling."""
+         # Transpose for CNN: (batch_size, hidden_size, seq_length)
+         x = unpooled_emb.transpose(1, 2)
+
+         # Apply CNN layers and collect outputs
+         conv_outputs = []
+         for conv in conv_layers:
+             conv_out = F.relu(conv(x))
+             conv_outputs.append(conv_out)
+
+         # Concatenate along the channel dimension
+         conv_output = torch.cat(conv_outputs, dim=1)
+
+         # Global pooling (both max and average); if an attention mask is
+         # provided, use it so padding positions do not contribute
+         if attention_mask is not None:
+             # Expand the mask (1 for valid positions, 0 for padding) to match conv_output channels
+             expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)
+
+             # Set padding to -inf so it never wins the max pooling
+             masked_output = conv_output.clone()
+             masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))
+
+             # Max pooling along the sequence dimension
+             max_pooled = torch.max(masked_output, dim=2)[0]
+
+             # Average pooling: sum divided by the number of valid positions
+             sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
+             valid_positions = torch.sum(expanded_mask, dim=2)
+             valid_positions = torch.clamp(valid_positions, min=1.0)  # avoid division by zero
+             avg_pooled = sum_pooled / valid_positions
+         else:
+             # Without a mask, use standard pooling
+             max_pooled = torch.max(conv_output, dim=2)[0]
+             avg_pooled = torch.mean(conv_output, dim=2)
+
+         # Concatenate the pooled features
+         pooled = torch.cat([max_pooled, avg_pooled], dim=1)
+
+         return pooled
+
+     def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
+         # Compute embeddings on the fly using the ESM model
+         protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
+         binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)
+
+         # Process protein and binder sequences through CNN layers
+         protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
+         binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)
+
+         # Project to the same dimension
+         protein = self.protein_norm(self.protein_projection(protein_features))
+         binder = self.binder_norm(self.binder_projection(binder_features))
+
+         # Reshape for attention: (batch_size, hidden_dim) -> (1, batch_size, hidden_dim)
+         protein = protein.unsqueeze(0)
+         binder = binder.unsqueeze(0)
+
+         # Cross-attention layers
+         for layer in self.cross_attention_layers:
+             # Protein attending to binder
+             attended_protein = layer['attention'](
+                 protein, binder, binder
+             )[0]
+             protein = layer['norm1'](protein + attended_protein)
+             protein = layer['norm2'](protein + layer['ffn'](protein))
+
+             # Binder attending to protein
+             attended_binder = layer['attention'](
+                 binder, protein, protein
+             )[0]
+             binder = layer['norm1'](binder + attended_binder)
+             binder = layer['norm2'](binder + layer['ffn'](binder))
+
+         # Remove the sequence dimension
+         protein_pool = protein.squeeze(0)
+         binder_pool = binder.squeeze(0)
+
+         # Concatenate both representations
+         combined = torch.cat([protein_pool, binder_pool], dim=-1)
+
+         # Shared features
+         shared_features = self.shared_head(combined)
+
+         regression_output = self.regression_head(shared_features)
+         classification_logits = self.classification_head(shared_features)
+
+         return regression_output, classification_logits
+
+
+ def load_model(checkpoint_path, device):
+     """Load the trained model from a checkpoint."""
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     # Initialize the model with the same parameters used during training
+     model = UnpooledBindingPredictor(
+         esm_model_name="facebook/esm2_t33_650M_UR50D",
+         hidden_dim=384,
+         kernel_sizes=[3, 5, 7],
+         n_heads=8,
+         n_layers=4,
+         dropout=0.14561457009902096,
+         freeze_esm=True
+     ).to(device)
+
+     # Load the trained weights
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()  # set to evaluation mode
+
+     return model
+
+
+ def prepare_inputs(protein_sequence, binder_sequence, tokenizer, max_length=1024, device='cuda'):
+     """Tokenize protein and binder sequences."""
+     protein_tokens = tokenizer(
+         protein_sequence,
+         return_tensors="pt",
+         padding="max_length",
+         max_length=max_length,
+         truncation=True
+     )
+
+     binder_tokens = tokenizer(
+         binder_sequence,
+         return_tensors="pt",
+         padding="max_length",
+         max_length=max_length,
+         truncation=True
+     )
+
+     return {
+         'protein_input_ids': protein_tokens['input_ids'].to(device),
+         'protein_attention_mask': protein_tokens['attention_mask'].to(device),
+         'binder_input_ids': binder_tokens['input_ids'].to(device),
+         'binder_attention_mask': binder_tokens['attention_mask'].to(device)
+     }
+
+
+ def predict_binding(model, protein_sequence, binder_sequence, device='cuda'):
+     """Predict binding affinity between protein and binder sequences."""
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+     inputs = prepare_inputs(protein_sequence, binder_sequence, tokenizer, device=device)
+
+     with torch.no_grad():
+         regression_output, classification_logits = model(
+             inputs['protein_input_ids'],
+             inputs['binder_input_ids'],
+             inputs['protein_attention_mask'],
+             inputs['binder_attention_mask']
+         )
+
+     # Numerical prediction (pKd/pKi)
+     predicted_affinity = regression_output.item()
+
+     # Classification prediction (tight, medium, weak)
+     predicted_class_idx = torch.argmax(classification_logits, dim=1).item()
+     class_names = ['Tight binding', 'Medium binding', 'Weak binding']
+     predicted_class = class_names[predicted_class_idx]
+
+     # Class probabilities
+     class_probs = F.softmax(classification_logits, dim=1).cpu().numpy()[0]
+
+     return {
+         'predicted_affinity': predicted_affinity,
+         'binding_class': predicted_class,
+         'class_probabilities': {name: prob for name, prob in zip(class_names, class_probs)},
+         'tight_threshold': model.tight_threshold,  # 7.5 (≤ ~30 nM)
+         'weak_threshold': model.weak_threshold     # 6.0 (> 1 μM)
+     }
+
+
+ # Example usage
+ if __name__ == "__main__":
+     # Set device
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Load the model
+     model = load_model('../classifier_ckpt/binding_affinity_unpooled.pt', device)
+
+     protein_sequence = "GSHMIEPNVISVRLFKRKVGGLGFLVKERVSKPPVIISDLIRGGAAEQSGLIQAGDIILAVNDRPLVDLSYDSALEVLRGIASETHVVLILRGPEGFTTHLETTFTGDGTPKTIRVTQPLGPPTKAV"
+     binder_sequence = "VVKVDSV"
+
+     result = predict_binding(model, protein_sequence, binder_sequence, device)  # fixed: previously passed an undefined `binder`
+     print(f"Affinity Score: {result['predicted_affinity']}")
classifier_code/binding_affinity_unpooled_2.py ADDED
@@ -0,0 +1,356 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import accuracy_score, f1_score
+ from scipy.stats import spearmanr
+ from collections import defaultdict
+ import pandas as pd
+ import logging
+ import os
+ import torch.optim as optim
+ from datetime import datetime
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
+
+ # Optionally point HF_ENDPOINT at a mirror
+ # os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+
+
+ class UnpooledBindingPredictor(nn.Module):
+     def __init__(self,
+                  esm_model_name="facebook/esm2_t33_650M_UR50D",
+                  hidden_dim=512,
+                  kernel_sizes=[3, 5, 7],
+                  n_heads=8,
+                  n_layers=3,
+                  dropout=0.1,
+                  freeze_esm=True):
+         super().__init__()
+
+         # Define binding thresholds on the pKd/pKi scale
+         self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30 nM
+         self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1 μM
+
+         # Load ESM model for computing embeddings on the fly
+         self.esm_model = AutoModel.from_pretrained(esm_model_name)
+         self.config = AutoConfig.from_pretrained(esm_model_name)
+
+         # Freeze ESM parameters if requested
+         if freeze_esm:
+             for param in self.esm_model.parameters():
+                 param.requires_grad = False
+
+         # Get ESM hidden size
+         esm_dim = self.config.hidden_size
+
+         # Output channels for CNN layers
+         output_channels_per_kernel = 64
+
+         # CNN layers for handling variable-length sequences
+         self.protein_conv_layers = nn.ModuleList([
+             nn.Conv1d(
+                 in_channels=esm_dim,
+                 out_channels=output_channels_per_kernel,
+                 kernel_size=k,
+                 padding='same'
+             ) for k in kernel_sizes
+         ])
+
+         self.binder_conv_layers = nn.ModuleList([
+             nn.Conv1d(
+                 in_channels=esm_dim,
+                 out_channels=output_channels_per_kernel,
+                 kernel_size=k,
+                 padding='same'
+             ) for k in kernel_sizes
+         ])
+
+         # Total features after convolution and (max + avg) pooling
+         total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2
+
+         # Project to the same dimension after CNN processing
+         self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
+         self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)
+
+         self.protein_norm = nn.LayerNorm(hidden_dim)
+         self.binder_norm = nn.LayerNorm(hidden_dim)
+
+         # Cross-attention blocks with layer norm
+         self.cross_attention_layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                 'norm1': nn.LayerNorm(hidden_dim),
+                 'ffn': nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim * 4),
+                     nn.ReLU(),
+                     nn.Dropout(dropout),
+                     nn.Linear(hidden_dim * 4, hidden_dim)
+                 ),
+                 'norm2': nn.LayerNorm(hidden_dim)
+             }) for _ in range(n_layers)
+         ])
+
+         # Prediction heads
+         self.shared_head = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+
+         # Regression head
+         self.regression_head = nn.Linear(hidden_dim, 1)
+
+         # Classification head (3 classes: tight, medium, weak binding)
+         self.classification_head = nn.Linear(hidden_dim, 3)
+
+     def get_binding_class(self, affinity):
+         """Convert affinity values to class indices.
+         0: tight binding (>= 7.5)
+         1: medium binding (6.0-7.5)
+         2: weak binding (< 6.0)
+         """
+         if isinstance(affinity, torch.Tensor):
+             tight_mask = affinity >= self.tight_threshold
+             weak_mask = affinity < self.weak_threshold
+             medium_mask = ~(tight_mask | weak_mask)
+
+             classes = torch.zeros_like(affinity, dtype=torch.long)
+             classes[medium_mask] = 1
+             classes[weak_mask] = 2
+             return classes
+         else:
+             if affinity >= self.tight_threshold:
+                 return 0  # tight binding
+             elif affinity < self.weak_threshold:
+                 return 2  # weak binding
+             else:
+                 return 1  # medium binding
+
+     def compute_embeddings(self, input_ids, attention_mask=None):
+         """Compute ESM embeddings on the fly."""
+         esm_outputs = self.esm_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             return_dict=True
+         )
+
+         # Unpooled last hidden states: (batch_size, seq_length, hidden_size)
+         return esm_outputs.last_hidden_state
+
+     def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
+         """Process a sequence through CNN layers and pooling."""
+         # Transpose for CNN: (batch_size, hidden_size, seq_length)
+         x = unpooled_emb.transpose(1, 2)
+
+         # Apply CNN layers and collect outputs
+         conv_outputs = []
+         for conv in conv_layers:
+             conv_out = F.relu(conv(x))
+             conv_outputs.append(conv_out)
+
+         # Concatenate along the channel dimension
+         conv_output = torch.cat(conv_outputs, dim=1)
+
+         # Global pooling (both max and average); if an attention mask is
+         # provided, use it so padding positions do not contribute
+         if attention_mask is not None:
+             # Expand the mask (1 for valid positions, 0 for padding) to match conv_output channels
+             expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)
+
+             # Set padding to -inf so it never wins the max pooling
+             masked_output = conv_output.clone()
+             masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))
+
+             # Max pooling along the sequence dimension
+             max_pooled = torch.max(masked_output, dim=2)[0]
+
+             # Average pooling: sum divided by the number of valid positions
+             sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
+             valid_positions = torch.sum(expanded_mask, dim=2)
+             valid_positions = torch.clamp(valid_positions, min=1.0)  # avoid division by zero
+             avg_pooled = sum_pooled / valid_positions
+         else:
+             # Without a mask, use standard pooling
+             max_pooled = torch.max(conv_output, dim=2)[0]
+             avg_pooled = torch.mean(conv_output, dim=2)
+
+         # Concatenate the pooled features
+         pooled = torch.cat([max_pooled, avg_pooled], dim=1)
+
+         return pooled
+
+     def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
+         # Compute embeddings on the fly using the ESM model
+         protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
+         binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)
+
+         # Process protein and binder sequences through CNN layers
+         protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
+         binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)
+
+         # Project to the same dimension
+         protein = self.protein_norm(self.protein_projection(protein_features))
+         binder = self.binder_norm(self.binder_projection(binder_features))
+
+         # Reshape for attention: (batch_size, hidden_dim) -> (1, batch_size, hidden_dim)
+         protein = protein.unsqueeze(0)
+         binder = binder.unsqueeze(0)
+
+         # Cross-attention layers
+         for layer in self.cross_attention_layers:
+             # Protein attending to binder
+             attended_protein = layer['attention'](
+                 protein, binder, binder
+             )[0]
+             protein = layer['norm1'](protein + attended_protein)
+             protein = layer['norm2'](protein + layer['ffn'](protein))
+
+             # Binder attending to protein
+             attended_binder = layer['attention'](
+                 binder, protein, protein
+             )[0]
+             binder = layer['norm1'](binder + attended_binder)
+             binder = layer['norm2'](binder + layer['ffn'](binder))
+
+         # Remove the sequence dimension
+         protein_pool = protein.squeeze(0)
+         binder_pool = binder.squeeze(0)
+
+         # Concatenate both representations
+         combined = torch.cat([protein_pool, binder_pool], dim=-1)
+
+         # Shared features
+         shared_features = self.shared_head(combined)
+
+         regression_output = self.regression_head(shared_features)
+         classification_logits = self.classification_head(shared_features)
+
+         return regression_output, classification_logits
+
+
+ def load_model(checkpoint_path, device):
+     """Load the trained model from a checkpoint."""
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     # Initialize the model with the same parameters used during training
+     model = UnpooledBindingPredictor(
+         esm_model_name="facebook/esm2_t33_650M_UR50D",
+         hidden_dim=384,
+         kernel_sizes=[3, 5, 7],
+         n_heads=8,
+         n_layers=4,
+         dropout=0.14561457009902096,
+         freeze_esm=True
+     ).to(device)
+
+     # Load the trained weights
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()  # set to evaluation mode
+
+     return model
+
+
+ def prepare_inputs(protein_sequence, binder_sequence, tokenizer, max_length=1024, device='cuda'):
+     """Tokenize protein and binder sequences."""
+     protein_tokens = tokenizer(
+         protein_sequence,
+         return_tensors="pt",
+         padding="max_length",
+         max_length=max_length,
+         truncation=True
+     )
+
+     binder_tokens = tokenizer(
+         binder_sequence,
+         return_tensors="pt",
+         padding="max_length",
+         max_length=max_length,
+         truncation=True
+     )
+
+     return {
+         'protein_input_ids': protein_tokens['input_ids'].to(device),
+         'protein_attention_mask': protein_tokens['attention_mask'].to(device),
+         'binder_input_ids': binder_tokens['input_ids'].to(device),
+         'binder_attention_mask': binder_tokens['attention_mask'].to(device)
+     }
+
+
+ def predict_binding(model, protein_sequence, binder_sequence, device='cuda'):
+     """Predict binding affinity between protein and binder sequences."""
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+     inputs = prepare_inputs(protein_sequence, binder_sequence, tokenizer, device=device)
+
+     with torch.no_grad():
+         regression_output, classification_logits = model(
+             inputs['protein_input_ids'],
+             inputs['binder_input_ids'],
+             inputs['protein_attention_mask'],
+             inputs['binder_attention_mask']
+         )
+
+     # Numerical prediction (pKd/pKi)
+     predicted_affinity = regression_output.item()
+
+     # Classification prediction (tight, medium, weak)
+     predicted_class_idx = torch.argmax(classification_logits, dim=1).item()
+     class_names = ['Tight binding', 'Medium binding', 'Weak binding']
+     predicted_class = class_names[predicted_class_idx]
+
+     # Class probabilities
+     class_probs = F.softmax(classification_logits, dim=1).cpu().numpy()[0]
+
+     return {
+         'predicted_affinity': predicted_affinity,
+         'binding_class': predicted_class,
+         'class_probabilities': {name: prob for name, prob in zip(class_names, class_probs)},
+         'tight_threshold': model.tight_threshold,  # 7.5 (≤ ~30 nM)
+         'weak_threshold': model.weak_threshold     # 6.0 (> 1 μM)
+     }
+
+
+ # Example usage
+ if __name__ == "__main__":
+     # Set device
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Load the model
+     model = load_model('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/binding_affinity_unpooled.pt', device)
+
+     # Example sequences (replace with actual sequences)
+     binders = ['GLSKGCFGLKLDRIGSMSGLGC', 'RGLSDGFLKLKMGISGSLGC']
+     protein_sequence = "RNLTLAVVLPEHNLSYAWAWPRVGPAVALAVEALGRALPVDLRFVSSELEGACSEYLAPLSAVDLKLYHDPDLLLGPGCVYPAASVARFASHWRLPLLTAGAVASGFSAKNDHYRTLVRTGPSAPKLGEFVVTLHGHFNWTARAALLYLDARTDDRPHYFTIEGVFEALQGSNLSVQHQVYAREPGGPEQATHFIRANGRIVYICGPLEMLHEILLQAQRENLTNGDYVFFYLDVFGESLRAGPTRATGRPWQDNRTREQAQALREAFQTVLVITYREPPNPEYQEFQNRLLIRAREDFGVELGPSLMNLIAGCFYDGILLYAEVLNETIQEGGTREDGLRIVEKMQGRRYHGVTGLVVMDKNNDRETDFVLWAMGDLDSGDFQPAAHYSGAEKQIWWTGRPIPWVKGAPPSDNPPCAFDLDDPSCDKTPLSTLAI"
+
+     # Alternative: score binders read from a samples file
+     # name = "CLIC1_10_moppit"
+     # with open(f'/home/tc415/flow_matching/samples/unconditional_samples/12.txt', 'r') as f:
+     #     binders = [binder.strip() for binder in f.readlines()][:100]
+
+     # Make a prediction for each binder
+     affinities = []
+     for binder in binders:
+         result = predict_binding(model, protein_sequence, binder, device)
+         print(result['predicted_affinity'])
+         affinities.append(result['predicted_affinity'])
+
+     # Optionally persist the scores
+     # with open(f'/home/tc415/flow_matching/scores/affinity/{name}.txt', 'w') as f:
+     #     for score in affinities:
+     #         f.write(str(round(score, 4)) + '\n')
+
+     # Display full results for a single prediction
+     # print(f"Predicted binding affinity (pKd/pKi): {result['predicted_affinity']:.2f}")
+     # print(f"Binding class: {result['binding_class']}")
+     # print("Class probabilities:")
+     # for class_name, prob in result['class_probabilities'].items():
+     #     print(f"  {class_name}: {prob:.2f}")
classifier_code/half_life.py ADDED
@@ -0,0 +1,65 @@
+ import numpy as np
+ import torch
+ import xgboost as xgb
+ from transformers import EsmModel, EsmTokenizer
+ import torch.nn as nn
+
+
+ class PeptideCNN(nn.Module):
+     def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
+         super().__init__()
+         self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
+         self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
+         self.fc = nn.Linear(hidden_dims[1], output_dim)
+         self.dropout = nn.Dropout(dropout_rate)
+         self.predictor = nn.Linear(output_dim, 1)  # regression/classification head
+
+         # Note: relies on the module-level `device`, which is defined below the
+         # class but exists by the time the model is instantiated
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
+         self.esm_model.eval()
+
+     def forward(self, input_ids, attention_mask=None, return_features=False):
+         # Only the frozen ESM pass runs without gradients
+         with torch.no_grad():
+             x = self.esm_model(input_ids, attention_mask).last_hidden_state
+         # x shape: (B, L, input_dim)
+         x = x.permute(0, 2, 1)  # reshape to (B, input_dim, L) for Conv1d
+         x = nn.functional.relu(self.conv1(x))
+         x = self.dropout(x)
+         x = nn.functional.relu(self.conv2(x))
+         x = self.dropout(x)
+         x = x.permute(0, 2, 1)  # reshape back to (B, L', hidden_dims[1])
+
+         # Global average pooling over the sequence dimension
+         x = x.mean(dim=1)  # shape: (B, hidden_dims[1])
+
+         features = self.fc(x)  # shape: (B, output_dim)
+         if return_features:
+             return features
+         return self.predictor(features)  # shape: (B, 1)
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ input_dim = 1280
+ hidden_dims = [input_dim // 2, input_dim // 4]
+ output_dim = input_dim // 8
+ dropout_rate = 0.3
+
+ nn_model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
+ nn_model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth'))
+ nn_model.eval()
+
+
+ def predict(inputs):
+     with torch.no_grad():
+         prediction = nn_model(**inputs, return_features=False)
+
+     return prediction.item()
+
+
+ if __name__ == '__main__':
+     sequence = 'RGLSDGFLKLKMGISGSLGC'
+
+     tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+     inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
+
+     prediction = predict(inputs)
+     print(prediction)
+     print(f"Predicted half life of {sequence} is {(10**prediction):.4f} h")
classifier_code/hemolysis_wt.py ADDED
@@ -0,0 +1,101 @@
+ import sys
+ import os
+ sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')
+ import xgboost as xgb
+ import torch
+ import numpy as np
+ import warnings
+ from rdkit import Chem, rdBase, DataStructs
+ from transformers import AutoTokenizer, EsmModel
+
+ rdBase.DisableLog('rdApp.error')
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+ class Hemolysis:
+     def __init__(self):
+         # Change this model path for your environment
+         self.predictor = xgb.Booster(model_file='/home/tc415/flow_matching/classifier_ckpt/best_model_hemolysis.json')
+
+         # Load ESM model and tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.model.eval()
+
+     def generate_embeddings(self, sequences):
+         """Generate mean-pooled ESM embeddings for protein sequences."""
+         embeddings = []
+
+         # Process sequences in batches to avoid memory issues
+         batch_size = 8
+         for i in range(0, len(sequences), batch_size):
+             batch_sequences = sequences[i:i + batch_size]
+
+             inputs = self.tokenizer(
+                 batch_sequences,
+                 padding=True,
+                 truncation=True,
+                 return_tensors="pt"
+             )
+
+             if torch.cuda.is_available():
+                 inputs = {k: v.cuda() for k, v in inputs.items()}
+                 self.model = self.model.cuda()
+
+             # Generate embeddings
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+
+             # Last hidden states
+             last_hidden_states = outputs.last_hidden_state
+
+             # Mean pooling, excluding padding tokens
+             attention_mask = inputs['attention_mask'].unsqueeze(-1)
+             masked_hidden_states = last_hidden_states * attention_mask
+             sum_hidden_states = masked_hidden_states.sum(dim=1)
+             seq_lengths = attention_mask.sum(dim=1)
+             batch_embeddings = sum_hidden_states / seq_lengths
+
+             batch_embeddings = batch_embeddings.cpu().numpy()
+             embeddings.append(batch_embeddings)
+
+         if embeddings:
+             return np.vstack(embeddings)
+         else:
+             return np.array([])
+
+     def get_scores(self, input_seqs: list):
+         scores = np.ones(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+
+         if len(features) == 0:
+             return scores
+
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+         features = xgb.DMatrix(features)
+
+         probs = self.predictor.predict(features)
+         # Return the probability of each sequence being NOT hemolytic, i.e. 1 - P(hemolytic)
+         return scores - probs
+
+     def __call__(self, input_seqs: list):
+         scores = self.get_scores(input_seqs)
+         return scores
+
+
+ def unittest():
+     hemolysis = Hemolysis()
+     sequences = [
+         "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
+         "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD"
+     ]
+
+     scores = hemolysis(input_seqs=sequences)
+     # Recover P(hemolytic) from the non-hemolysis scores
+     print([1 - score for score in scores])
+
+
+ if __name__ == '__main__':
+     unittest()
classifier_code/nonfouling_wt.py ADDED
@@ -0,0 +1,98 @@
+ import sys
+ import os
+ import xgboost as xgb
+ import torch
+ import numpy as np
+ import warnings
+ from rdkit import Chem, rdBase, DataStructs
+ from transformers import AutoTokenizer, EsmModel
+
+ rdBase.DisableLog('rdApp.error')
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+ class Nonfouling:
+     def __init__(self):
+         # Change this model path for your environment
+         self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_nonfouling.json')
+
+         # Load ESM model and tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.model.eval()
+
+     def generate_embeddings(self, sequences):
+         """Generate mean-pooled ESM embeddings for protein sequences."""
+         embeddings = []
+
+         # Process sequences in batches to avoid memory issues
+         batch_size = 8
+         for i in range(0, len(sequences), batch_size):
+             batch_sequences = sequences[i:i + batch_size]
+
+             inputs = self.tokenizer(
+                 batch_sequences,
+                 padding=True,
+                 truncation=True,
+                 return_tensors="pt"
+             )
+
+             if torch.cuda.is_available():
+                 inputs = {k: v.cuda() for k, v in inputs.items()}
+                 self.model = self.model.cuda()
+
+             # Generate embeddings
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+
+             # Last hidden states
+             last_hidden_states = outputs.last_hidden_state
+
+             # Mean pooling, excluding padding tokens
+             attention_mask = inputs['attention_mask'].unsqueeze(-1)
+             masked_hidden_states = last_hidden_states * attention_mask
+             sum_hidden_states = masked_hidden_states.sum(dim=1)
+             seq_lengths = attention_mask.sum(dim=1)
+             batch_embeddings = sum_hidden_states / seq_lengths
+
+             batch_embeddings = batch_embeddings.cpu().numpy()
+             embeddings.append(batch_embeddings)
+
+         if embeddings:
+             return np.vstack(embeddings)
+         else:
+             return np.array([])
+
+     def get_scores(self, input_seqs: list):
+         scores = np.zeros(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+
+         if len(features) == 0:
+             return scores
+
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+         features = xgb.DMatrix(features)
+
+         scores = self.predictor.predict(features)
+         return scores
+
+     def __call__(self, input_seqs: list):
+         scores = self.get_scores(input_seqs)
+         return scores
+
+
+ def unittest():
+     nonfouling = Nonfouling()
+     sequences = [
+         "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
+         "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD"
+     ]
+
+     scores = nonfouling(input_seqs=sequences)
+     print(scores)
+
+
+ if __name__ == '__main__':
+     unittest()
classifier_code/solubility_wt.py ADDED
@@ -0,0 +1,98 @@
+ import sys
+ import os
+ import xgboost as xgb
+ import torch
+ import numpy as np
+ import warnings
+ from rdkit import Chem, rdBase, DataStructs
+ from transformers import AutoTokenizer, EsmModel
+
+ rdBase.DisableLog('rdApp.error')
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+ class Solubility:
+     def __init__(self):
+         # Change this model path for your environment
+         self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')
+
+         # Load ESM model and tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.model.eval()
+
+     def generate_embeddings(self, sequences):
+         """Generate mean-pooled ESM embeddings for protein sequences."""
+         embeddings = []
+
+         # Process sequences in batches to avoid memory issues
+         batch_size = 8
+         for i in range(0, len(sequences), batch_size):
+             batch_sequences = sequences[i:i + batch_size]
+
+             inputs = self.tokenizer(
+                 batch_sequences,
+                 padding=True,
+                 truncation=True,
+                 return_tensors="pt"
+             )
+
+             if torch.cuda.is_available():
+                 inputs = {k: v.cuda() for k, v in inputs.items()}
+                 self.model = self.model.cuda()
+
+             # Generate embeddings
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+
+             # Last hidden states
+             last_hidden_states = outputs.last_hidden_state
+
+             # Mean pooling, excluding padding tokens
+             attention_mask = inputs['attention_mask'].unsqueeze(-1)
+             masked_hidden_states = last_hidden_states * attention_mask
+             sum_hidden_states = masked_hidden_states.sum(dim=1)
+             seq_lengths = attention_mask.sum(dim=1)
+             batch_embeddings = sum_hidden_states / seq_lengths
+
+             batch_embeddings = batch_embeddings.cpu().numpy()
+             embeddings.append(batch_embeddings)
+
+         if embeddings:
+             return np.vstack(embeddings)
+         else:
+             return np.array([])
+
+     def get_scores(self, input_seqs: list):
+         scores = np.zeros(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+
+         if len(features) == 0:
+             return scores
+
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+
+         features = xgb.DMatrix(features)
+
+         scores = self.predictor.predict(features)
+         return scores
+
+     def __call__(self, input_seqs: list):
+         scores = self.get_scores(input_seqs)
+         return scores
+
+
+ def unittest():
+     solubility = Solubility()
+     sequences = [
+         "GLSKGCFGLKLDRIGSMSGLGC",
+         "RGLSDGFLKLKMGISGSLGC"
+     ]
+
+     scores = solubility(input_seqs=sequences)
+     print(scores)
+
+
+ if __name__ == '__main__':
+     unittest()
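
Since Solubility, Nonfouling, and Hemolysis expose the same callable interface over a list of sequences, they compose naturally into a multi-property screen. A sketch, assuming the hard-coded checkpoint paths inside each class resolve on the host:

from classifier_code.solubility_wt import Solubility
from classifier_code.nonfouling_wt import Nonfouling
from classifier_code.hemolysis_wt import Hemolysis

seqs = ["GLSKGCFGLKLDRIGSMSGLGC", "RGLSDGFLKLKMGISGSLGC"]
scorers = {
    "solubility": Solubility(),
    "nonfouling": Nonfouling(),
    "non-hemolysis": Hemolysis(),  # note: returns 1 - P(hemolytic)
}
for name, scorer in scorers.items():
    print(name, scorer(seqs))  # each returns a numpy array of per-sequence scores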
flow_matching/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ __version__ = "1.0.10"
flow_matching/loss/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .generalized_loss import MixturePathGeneralizedKL
+
+ __all__ = [
+     "MixturePathGeneralizedKL",
+ ]
flow_matching/loss/generalized_loss.py ADDED
@@ -0,0 +1,83 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from torch import Tensor
+ from torch.nn.modules.loss import _Loss
+
+ from flow_matching.path import MixtureDiscreteProbPath
+
+
+ class MixturePathGeneralizedKL(_Loss):
+     r"""A generalized KL loss for discrete flow matching.
+     A class that measures the generalized KL of a discrete flow model :math:`p_{1|t}` w.r.t. a probability path given by ``path``. Note: this class assumes that the model is trained on the same path.
+
+     For a model trained on a space :math:`\mathcal{S} = \mathcal{T}^d`, :math:`\mathcal{T} = [K] = \{1, 2, \ldots, K\}`, the loss is given by
+
+     .. math::
+         \ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggl[ p_{1|t}(x_t^i|x_t) - \delta_{x^i_1}(x_t^i) + (1-\delta_{x^i_1}(x_t^i))\log p_{1|t}(x_1^i|x_t) \biggr],
+
+     where :math:`\kappa_t` is the scheduler associated with ``path``.
+
+     Args:
+         path (MixtureDiscreteProbPath): Probability path (x-prediction training).
+         reduction (str, optional): Specify the reduction to apply to the output ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction is applied to the output, ``'mean'``: the output is reduced by mean over sequence elements, ``'sum'``: the output is reduced by sum over sequence elements. Defaults to 'mean'.
+     """
+
+     def __init__(self, path: MixtureDiscreteProbPath, reduction: str = "mean") -> None:
+         super().__init__(None, None, reduction)
+         self.path = path
+
+     def forward(self, logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+         r"""Evaluates the generalized KL loss.
+
+         Args:
+             logits (Tensor): posterior model output (i.e., softmax(``logits``) :math:`=p_{1|t}(x|x_t)`), shape (batch, d, K).
+             x_1 (Tensor): target data point :math:`x_1 \sim q`, shape (batch, d).
+             x_t (Tensor): conditional sample :math:`x_t \sim p_t(\cdot|x_1)`, shape (batch, d).
+             t (Tensor): times in :math:`[0,1]`, shape (batch).
+
+         Raises:
+             ValueError: reduction value must be one of ``'none'`` | ``'mean'`` | ``'sum'``.
+
+         Returns:
+             Tensor: Generalized KL loss.
+         """
+         x_1_shape = x_1.shape
+
+         # extract the x_1 value of log(p_{1|t}(x|x_t))
+         log_p_1t = torch.log_softmax(logits, dim=-1)
+         log_p_1t_x1 = torch.gather(log_p_1t, dim=-1, index=x_1.unsqueeze(-1))
+         log_p_1t_x1 = log_p_1t_x1.view(*x_1_shape)
+
+         # extract the x_t value of p_{1|t}(x|x_t)
+         p_1t = torch.exp(log_p_1t)
+         p_1t_xt = torch.gather(p_1t, dim=-1, index=x_t.unsqueeze(-1))
+         p_1t_xt = p_1t_xt.view(*x_1_shape)
+
+         scheduler_output = self.path.scheduler(t)
+
+         jump_coefficient = (
+             scheduler_output.d_alpha_t / (1 - scheduler_output.alpha_t)
+         )[(...,) + (None,) * (x_1.dim() - 1)]
+         jump_coefficient = jump_coefficient.repeat(1, *x_1_shape[1:])
+         delta_x1_xt = (x_t == x_1).to(log_p_1t.dtype)
+
+         loss = -jump_coefficient * (
+             p_1t_xt - delta_x1_xt + (1 - delta_x1_xt) * log_p_1t_x1
+         )
+
+         mask = (x_1 != 1).to(loss.dtype)  # token id 1 is the mask token; exclude it from the loss
+         loss = loss * mask
+
+         if self.reduction == "mean":
+             return torch.mean(loss)
+         elif self.reduction == "sum":
+             return torch.sum(loss)
+         elif self.reduction == "none":
+             return loss
+         else:
+             raise ValueError(f"{self.reduction} is not a valid value for reduction")
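
A minimal sketch of wiring this loss into a discrete flow matching training step, assuming the MixtureDiscreteProbPath export above and a PolynomialConvexScheduler from the scheduler package (both part of the upstream flow_matching library); the logits tensor is a stand-in for a real model call and all shapes are placeholders:

import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.loss import MixturePathGeneralizedKL

B, d, K = 8, 16, 32  # batch size, sequence length, vocabulary size (placeholders)
path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
loss_fn = MixturePathGeneralizedKL(path=path)

x_1 = torch.randint(0, K, (B, d))  # target tokens
x_0 = torch.randint(0, K, (B, d))  # source tokens (e.g. noise or masks)
t = torch.rand(B)                  # times in [0, 1)
sample = path.sample(x_0=x_0, x_1=x_1, t=t)

logits = torch.randn(B, d, K, requires_grad=True)  # stand-in for model(sample.x_t, t)
# Note: this fork additionally masks positions where x_1 == 1 (the mask token)
loss = loss_fn(logits=logits, x_1=x_1, x_t=sample.x_t, t=t)
loss.backward()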
flow_matching/path/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .affine import AffineProbPath, CondOTProbPath
+ from .geodesic import GeodesicProbPath
+ from .mixture import MixtureDiscreteProbPath
+ from .path import ProbPath
+ from .path_sample import DiscretePathSample, PathSample
+
+
+ __all__ = [
+     "ProbPath",
+     "AffineProbPath",
+     "CondOTProbPath",
+     "MixtureDiscreteProbPath",
+     "GeodesicProbPath",
+     "PathSample",
+     "DiscretePathSample",
+ ]
flow_matching/path/affine.py ADDED
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from torch import Tensor

from flow_matching.path.path import ProbPath
from flow_matching.path.path_sample import PathSample
from flow_matching.path.scheduler.scheduler import CondOTScheduler, Scheduler
from flow_matching.utils import expand_tensor_like


class AffineProbPath(ProbPath):
    r"""The ``AffineProbPath`` class represents a specific type of probability path where the transformation between distributions is affine.
    An affine transformation can be represented as:

    .. math::

        X_t = \alpha_t X_1 + \sigma_t X_0,

    where :math:`X_t` is the transformed data point at time `t`. :math:`X_0` and :math:`X_1` are the source and target data points, respectively. :math:`\alpha_t` and :math:`\sigma_t` are the parameters of the affine transformation at time `t`.

    The scheduler is responsible for providing the time-dependent parameters :math:`\alpha_t` and :math:`\sigma_t`, as well as their derivatives, which define the affine transformation at any given time `t`.

    Using ``AffineProbPath`` in the flow matching framework:

    .. code-block:: python

        # Instantiates a probability path
        my_path = AffineProbPath(...)
        mse_loss = torch.nn.MSELoss()

        for x_1 in dataset:
            # Sets x_0 to random noise
            x_0 = torch.randn_like(x_1)

            # Sets t to a random value in [0,1]
            t = torch.rand(x_1.shape[0])

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Computes the MSE loss w.r.t. the velocity
            loss = mse_loss(path_sample.dx_t, my_model(path_sample.x_t, t))
            loss.backward()

    Args:
        scheduler (Scheduler): An instance of a scheduler that provides the parameters :math:`\alpha_t`, :math:`\sigma_t`, and their derivatives over time.

    """

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from the affine probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.
        | return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t, \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        scheduler_output = self.scheduler(t)

        alpha_t = expand_tensor_like(
            input_tensor=scheduler_output.alpha_t, expand_to=x_1
        )
        sigma_t = expand_tensor_like(
            input_tensor=scheduler_output.sigma_t, expand_to=x_1
        )
        d_alpha_t = expand_tensor_like(
            input_tensor=scheduler_output.d_alpha_t, expand_to=x_1
        )
        d_sigma_t = expand_tensor_like(
            input_tensor=scheduler_output.d_sigma_t, expand_to=x_1
        )

        # construct x_t ~ p_t(x|x_1).
        x_t = sigma_t * x_0 + alpha_t * x_1
        dx_t = d_sigma_t * x_0 + d_alpha_t * x_1

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)

    def target_to_velocity(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from the x_1 representation to velocity.

        | given :math:`X_1`.
        | return :math:`\dot{X}_t`.

        Args:
            x_1 (Tensor): target data point.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = d_sigma_t / sigma_t
        b_t = (d_alpha_t * sigma_t - d_sigma_t * alpha_t) / sigma_t

        return a_t * x_t + b_t * x_1

    def epsilon_to_velocity(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from the epsilon representation to velocity.

        | given :math:`\epsilon`.
        | return :math:`\dot{X}_t`.

        Args:
            epsilon (Tensor): noise in the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = d_alpha_t / alpha_t
        b_t = (d_sigma_t * alpha_t - d_alpha_t * sigma_t) / alpha_t

        return a_t * x_t + b_t * epsilon

    def velocity_to_target(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from velocity to the x_1 representation.

        | given :math:`\dot{X}_t`.
        | return :math:`X_1`.

        Args:
            velocity (Tensor): velocity at the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: target data point.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = -d_sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t)
        b_t = sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t)

        return a_t * x_t + b_t * velocity

    def epsilon_to_target(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from the epsilon representation to the x_1 representation.

        | given :math:`\epsilon`.
        | return :math:`X_1`.

        Args:
            epsilon (Tensor): noise in the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: target data point.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        sigma_t = scheduler_output.sigma_t

        a_t = 1 / alpha_t
        b_t = -sigma_t / alpha_t

        return a_t * x_t + b_t * epsilon

    def velocity_to_epsilon(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from velocity to the noise representation.

        | given :math:`\dot{X}_t`.
        | return :math:`\epsilon`.

        Args:
            velocity (Tensor): velocity at the path sample.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: noise in the path sample.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        d_alpha_t = scheduler_output.d_alpha_t
        sigma_t = scheduler_output.sigma_t
        d_sigma_t = scheduler_output.d_sigma_t

        a_t = -d_alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t)
        b_t = alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t)

        return a_t * x_t + b_t * velocity

    def target_to_epsilon(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
        r"""Convert from the x_1 representation to the epsilon representation.

        | given :math:`X_1`.
        | return :math:`\epsilon`.

        Args:
            x_1 (Tensor): target data point.
            x_t (Tensor): path sample at time t.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: noise in the path sample.
        """
        scheduler_output = self.scheduler(t)

        alpha_t = scheduler_output.alpha_t
        sigma_t = scheduler_output.sigma_t

        a_t = 1 / sigma_t
        b_t = -alpha_t / sigma_t

        return a_t * x_t + b_t * x_1


class CondOTProbPath(AffineProbPath):
    r"""The ``CondOTProbPath`` class represents a conditional optimal transport probability path.

    This class is a specialized version of the ``AffineProbPath`` that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation.

    The parameters :math:`\alpha_t` and :math:`\sigma_t` for the conditional optimal transport path are defined as:

    .. math::

        \alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.
    """

    def __init__(self):
        self.scheduler = CondOTScheduler()
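A small runnable check of the conversion identities above (a sketch; it uses a batch of scalar samples so the unexpanded scheduler outputs broadcast cleanly against the data in the conversion helpers):

    import torch
    from flow_matching.path import CondOTProbPath

    path = CondOTProbPath()
    x_0, x_1 = torch.randn(8), torch.randn(8)  # batch of scalar samples
    t = torch.rand(8)

    sample = path.sample(x_0=x_0, x_1=x_1, t=t)
    # For CondOT, X_t = t X_1 + (1 - t) X_0 and dX_t = X_1 - X_0.
    assert torch.allclose(sample.dx_t, x_1 - x_0)
    # velocity_to_target inverts target_to_velocity: recover X_1 from dX_t.
    recovered = path.velocity_to_target(sample.dx_t, sample.x_t, t)
    assert torch.allclose(recovered, x_1, atol=1e-5)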
flow_matching/path/geodesic.py ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch import Tensor
from torch.func import jvp, vmap

from flow_matching.path.path import ProbPath

from flow_matching.path.path_sample import PathSample
from flow_matching.path.scheduler import ConvexScheduler
from flow_matching.utils import expand_tensor_like

from flow_matching.utils.manifolds import geodesic, Manifold


class GeodesicProbPath(ProbPath):
    r"""The ``GeodesicProbPath`` class represents a specific type of probability path where the transformation between distributions is defined through the geodesic path.
    Mathematically, a geodesic path can be represented as:

    .. math::

        X_t = \psi_t(X_0 | X_1) = \exp_{X_0}(\kappa_t \log_{X_0}(X_1)),

    where :math:`X_t` is the transformed data point at time `t`, :math:`X_0` and :math:`X_1` are the source and target data points, respectively, and :math:`\kappa_t` is a scheduler increasing from 0 to 1.

    The scheduler is responsible for providing the time-dependent :math:`\kappa_t` and must be differentiable.

    Using ``GeodesicProbPath`` in the flow matching framework:

    .. code-block:: python

        # Instantiates a manifold
        manifold = FlatTorus()

        # Instantiates a scheduler
        scheduler = CondOTScheduler()

        # Instantiates a probability path
        my_path = GeodesicProbPath(scheduler, manifold)
        mse_loss = torch.nn.MSELoss()

        for x_1 in dataset:
            # Sets x_0 to random noise
            x_0 = torch.randn_like(x_1)

            # Sets t to a random value in [0,1]
            t = torch.rand(x_1.shape[0])

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Computes the MSE loss w.r.t. the velocity
            loss = mse_loss(path_sample.dx_t, my_model(path_sample.x_t, t))
            loss.backward()

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\kappa_t`.
        manifold (Manifold): The manifold on which the probability path is defined.

    """

    def __init__(self, scheduler: ConvexScheduler, manifold: Manifold):
        self.scheduler = scheduler
        self.manifold = manifold

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from the Riemannian probability path with geodesic interpolation:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`\kappa_t`.
        | return :math:`X_0, X_1, X_t = \exp_{X_0}(\kappa_t \log_{X_0}(X_1))`, and the conditional velocity at :math:`X_t, \dot{X}_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: A conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
        t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()

        def cond_u(x_0, x_1, t):
            path = geodesic(self.manifold, x_0, x_1)
            x_t, dx_t = jvp(
                lambda t: path(self.scheduler(t).alpha_t),
                (t,),
                (torch.ones_like(t).to(t),),
            )
            return x_t, dx_t

        x_t, dx_t = vmap(cond_u)(x_0, x_1, t)
        x_t = x_t.reshape_as(x_1)
        dx_t = dx_t.reshape_as(x_1)

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
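A usage sketch on a concrete manifold (assuming ``FlatTorus`` is exported from ``flow_matching.utils.manifolds``, per the ``torus.py`` module in this upload; this is illustrative, not part of the file above):

    import torch
    from flow_matching.path import GeodesicProbPath
    from flow_matching.path.scheduler import CondOTScheduler
    from flow_matching.utils.manifolds import FlatTorus  # assumed export

    path = GeodesicProbPath(scheduler=CondOTScheduler(), manifold=FlatTorus())
    x_0 = torch.rand(4, 3) * 2 * torch.pi  # angular coordinates on the torus
    x_1 = torch.rand(4, 3) * 2 * torch.pi
    t = torch.rand(4)

    sample = path.sample(x_0=x_0, x_1=x_1, t=t)  # x_t and dx_t stay on the manifold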
flow_matching/path/mixture.py ADDED
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from torch import Tensor

from flow_matching.path.path import ProbPath

from flow_matching.path.path_sample import DiscretePathSample
from flow_matching.path.scheduler import ConvexScheduler
from flow_matching.utils import expand_tensor_like, unsqueeze_to_match


class MixtureDiscreteProbPath(ProbPath):
    r"""The ``MixtureDiscreteProbPath`` class defines a factorized discrete probability path.

    This path remains constant at the source data point :math:`X_0` until a random time, determined by the scheduler, when it flips to the target data point :math:`X_1`.
    The scheduler determines the flip probability using the parameter :math:`\sigma_t`, which is a function of time `t`. Specifically, :math:`\sigma_t` represents the probability of remaining at :math:`X_0`, while :math:`1 - \sigma_t` is the probability of flipping to :math:`X_1`:

    .. math::

        P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,

    where :math:`\sigma_t` is provided by the scheduler.

    Example:

    .. code-block:: python

        >>> x_0 = torch.zeros((1, 3, 3))
        >>> x_1 = torch.ones((1, 3, 3))

        >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t
        >>> result
        tensor([[[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t
        >>> result
        tensor([[[1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 1.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t
        >>> result
        tensor([[[1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0]]])

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
    """

    def __init__(self, scheduler: ConvexScheduler):
        assert isinstance(
            scheduler, ConvexScheduler
        ), "Scheduler for MixtureDiscreteProbPath must be a ConvexScheduler."

        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        r"""Sample from the mixture discrete probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler providing :math:`\sigma_t`.
        | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            DiscretePathSample: a conditional sample at :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        sigma_t = self.scheduler(t).sigma_t

        sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

        source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
        x_t = torch.where(condition=source_indices, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

    def posterior_to_velocity(
        self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
    ) -> Tensor:
        r"""Convert the factorized posterior to velocity.

        | given :math:`p(X_1|X_t)`. In the factorized case: :math:`\prod_i p(X_1^i | X_t)`.
        | return :math:`u_t`.

        Args:
            posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
            x_t (Tensor): path sample at time t, shape (...).
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity.
        """
        posterior = torch.softmax(posterior_logits, dim=-1)
        vocabulary_size = posterior.shape[-1]
        x_t = F.one_hot(x_t, num_classes=vocabulary_size)
        t = unsqueeze_to_match(source=t, target=x_t)

        scheduler_output = self.scheduler(t)

        kappa_t = scheduler_output.alpha_t
        d_kappa_t = scheduler_output.d_alpha_t

        return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)
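A short sketch pairing ``sample`` with ``posterior_to_velocity`` (illustrative; random logits stand in for a trained posterior model):

    import torch
    from flow_matching.path import MixtureDiscreteProbPath
    from flow_matching.path.scheduler import PolynomialConvexScheduler

    path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=2.0))
    vocab_size = 10
    x_0 = torch.randint(vocab_size, (2, 5))
    x_1 = torch.randint(vocab_size, (2, 5))
    t = torch.tensor([0.3, 0.8])

    sample = path.sample(x_0=x_0, x_1=x_1, t=t)
    posterior_logits = torch.randn(2, 5, vocab_size)  # stand-in for model(x_t, t)
    u_t = path.posterior_to_velocity(posterior_logits, sample.x_t, t)  # (2, 5, vocab_size)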
flow_matching/path/path.py ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

from torch import Tensor

from flow_matching.path.path_sample import PathSample


class ProbPath(ABC):
    r"""Abstract class, representing a probability path.

    A probability path transforms the distribution :math:`p(X_0)` into :math:`p(X_1)` over :math:`t=0\rightarrow 1`.

    The ``ProbPath`` class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives.
    Here is a high-level example:

    .. code-block:: python

        # Instantiate a probability path
        my_path = ProbPath(...)

        for x_0, x_1 in dataset:
            # Sets t to a random value in [0,1]
            t = torch.rand(x_1.shape[0])

            # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Optimizes the model. The loss function varies, depending on model and path.
            loss(path_sample, my_model(path_sample.x_t, t)).backward()

    """

    @abstractmethod
    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from an abstract probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)`.
        | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all packaged in a ``PathSample``.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample.
        """

    def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
        assert (
            t.ndim == 1
        ), f"The time vector t must have shape [batch_size]. Got {t.shape}."
        assert (
            t.shape[0] == x_0.shape[0] == x_1.shape[0]
        ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"
flow_matching/path/path_sample.py ADDED
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from torch import Tensor


@dataclass
class PathSample:
    r"""Represents a sample of a conditional-flow generated probability path.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
        dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape (batch_size, ...).

    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
    dx_t: Tensor = field(
        metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."}
    )


@dataclass
class DiscretePathSample:
    r"""
    Represents a sample of a conditional-flow generated discrete probability path.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
flow_matching/path/scheduler/__init__.py ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .schedule_transform import ScheduleTransformedModel
from .scheduler import (
    CondOTScheduler,
    ConvexScheduler,
    CosineScheduler,
    LinearVPScheduler,
    PolynomialConvexScheduler,
    Scheduler,
    SchedulerOutput,
    VPScheduler,
)

__all__ = [
    "CondOTScheduler",
    "CosineScheduler",
    "ConvexScheduler",
    "PolynomialConvexScheduler",
    "ScheduleTransformedModel",
    "Scheduler",
    "VPScheduler",
    "LinearVPScheduler",
    "SchedulerOutput",
]
flow_matching/path/scheduler/schedule_transform.py ADDED
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from torch import Tensor

from flow_matching.path.scheduler.scheduler import Scheduler
from flow_matching.utils import ModelWrapper


class ScheduleTransformedModel(ModelWrapper):
    """
    Change of scheduler for a velocity model.

    This class wraps a given velocity model and transforms its scheduling
    to a new scheduler function. It modifies the time
    dynamics of the model according to the new scheduler while maintaining
    the original model's behavior.

    Example:

    .. code-block:: python

        import torch
        from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
        from flow_matching.solver import ODESolver

        # Initialize the model and schedulers
        model = ...

        original_scheduler = CondOTScheduler()
        new_scheduler = CosineScheduler()

        # Create the transformed model
        transformed_model = ScheduleTransformedModel(
            velocity_model=model,
            original_scheduler=original_scheduler,
            new_scheduler=new_scheduler,
        )

        # Set up the solver
        solver = ODESolver(velocity_model=transformed_model)

        x_0 = torch.randn([10, 2])  # Example initial condition

        x_1 = solver.sample(
            time_grid=torch.tensor([0.0, 1.0]),
            x_init=x_0,
            step_size=1 / 1000,
        )

    Args:
        velocity_model (ModelWrapper): The original velocity model to be transformed.
        original_scheduler (Scheduler): The scheduler used by the original model. Must implement the snr_inverse function.
        new_scheduler (Scheduler): The new scheduler to be applied to the model.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        original_scheduler: Scheduler,
        new_scheduler: Scheduler,
    ):
        super().__init__(model=velocity_model)
        self.original_scheduler = original_scheduler
        self.new_scheduler = new_scheduler

        assert hasattr(self.original_scheduler, "snr_inverse") and callable(
            getattr(self.original_scheduler, "snr_inverse")
        ), "The original scheduler must have a callable 'snr_inverse' method."

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Compute the transformed marginal velocity field for a new scheduler.
        This method implements a post-training velocity scheduler change for
        affine conditional flows. It transforms a generating marginal velocity
        field :math:`u_t(x)` based on an original scheduler to a new marginal velocity
        field :math:`\bar{u}_r(x)` based on a different scheduler, while maintaining
        the same data coupling.
        The transformation is based on the scale-time (ST) transformation
        between the two conditional flows, defined as:

        .. math::

            \bar{X}_r = s_r X_{t_r},

        where :math:`X_t` and :math:`\bar{X}_r` are defined by their respective schedulers.
        The ST transformation is computed as:

        .. math::

            t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.

        Here, :math:`\rho(t)` is the signal-to-noise ratio (SNR) defined as:

        .. math::

            \rho(t) = \frac{\alpha_t}{\sigma_t}.

        :math:`\bar{\rho}(r)` is similarly defined for the new scheduler.
        The marginal velocity for the new scheduler is then given by:

        .. math::

            \bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).

        Args:
            x (Tensor): :math:`x_t`, the input tensor.
            t (Tensor): The time tensor (denoted as :math:`r` above).
            **extras: Additional arguments for the model.
        Returns:
            Tensor: The transformed velocity.
        """
        r = t

        r_scheduler_output = self.new_scheduler(t=r)

        alpha_r = r_scheduler_output.alpha_t
        sigma_r = r_scheduler_output.sigma_t
        d_alpha_r = r_scheduler_output.d_alpha_t
        d_sigma_r = r_scheduler_output.d_sigma_t

        t = self.original_scheduler.snr_inverse(alpha_r / sigma_r)

        t_scheduler_output = self.original_scheduler(t=t)

        alpha_t = t_scheduler_output.alpha_t
        sigma_t = t_scheduler_output.sigma_t
        d_alpha_t = t_scheduler_output.d_alpha_t
        d_sigma_t = t_scheduler_output.d_sigma_t

        s_r = sigma_r / sigma_t

        dt_r = (
            sigma_t
            * sigma_t
            * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)
            / (sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t))
        )

        ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t)

        u_t = self.model(x=x / s_r, t=t, **extras)
        u_r = ds_r * x / s_r + dt_r * s_r * u_t

        return u_r
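The time map t_r = rho^{-1}(rho_bar(r)) used above can be checked numerically: pulling the new scheduler's SNR back through the original scheduler must reproduce the same SNR (illustrative sketch, using two schedulers from this upload):

    import torch
    from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler

    original, new = CondOTScheduler(), CosineScheduler()
    r = torch.tensor([0.25, 0.5, 0.75])
    out_r = new(r)

    # t_r = rho^{-1}(rho_bar(r))
    t = original.snr_inverse(out_r.alpha_t / out_r.sigma_t)
    out_t = original(t)
    assert torch.allclose(
        out_t.alpha_t / out_t.sigma_t, out_r.alpha_t / out_r.sigma_t, atol=1e-6
    )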
flow_matching/path/scheduler/scheduler.py ADDED
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from typing import Union

import torch

from torch import Tensor


@dataclass
class SchedulerOutput:
    r"""Represents the scheduler parameters and their derivatives at a given time.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).

    """

    alpha_t: Tensor = field(metadata={"help": "alpha_t"})
    sigma_t: Tensor = field(metadata={"help": "sigma_t"})
    d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
    d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})


class Scheduler(ABC):
    """Base Scheduler class."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""
        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise ratio, shape (...).

        Returns:
            Tensor: t, shape (...).
        """
        ...


class ConvexScheduler(Scheduler):
    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Scheduler for convex paths.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...).

        Returns:
            Tensor: t, shape (...).
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise ratio, shape (...).

        Returns:
            Tensor: t, shape (...).
        """
        kappa_t = snr / (1.0 + snr)

        return self.kappa_inverse(kappa=kappa_t)


class CondOTScheduler(ConvexScheduler):
    """CondOT Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-torch.ones_like(t),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return kappa


class PolynomialConvexScheduler(ConvexScheduler):
    """Polynomial Scheduler."""

    def __init__(self, n: Union[float, int]) -> None:
        assert isinstance(
            n, (float, int)
        ), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."

        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t**self.n,
            sigma_t=1 - t**self.n,
            d_alpha_t=self.n * (t ** (self.n - 1)),
            d_sigma_t=-self.n * (t ** (self.n - 1)),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return torch.pow(kappa, 1.0 / self.n)


class VPScheduler(Scheduler):
    """Variance Preserving Scheduler."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        b = self.beta_min
        B = self.beta_max
        T = 0.5 * (1 - t) ** 2 * (B - b) + (1 - t) * b
        dT = -(1 - t) * (B - b) - b

        return SchedulerOutput(
            alpha_t=torch.exp(-0.5 * T),
            sigma_t=torch.sqrt(1 - torch.exp(-T)),
            d_alpha_t=-0.5 * dT * torch.exp(-0.5 * T),
            d_sigma_t=0.5 * dT * torch.exp(-T) / torch.sqrt(1 - torch.exp(-T)),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        T = -torch.log(snr**2 / (snr**2 + 1))
        b = self.beta_min
        B = self.beta_max
        t = 1 - ((-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b))
        return t


class LinearVPScheduler(Scheduler):
    """Linear Variance Preserving Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=(1 - t**2) ** 0.5,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / (1 - t**2) ** 0.5,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        return torch.sqrt(snr**2 / (1 + snr**2))


class CosineScheduler(Scheduler):
    """Cosine Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        pi = torch.pi
        return SchedulerOutput(
            alpha_t=torch.sin(pi / 2 * t),
            sigma_t=torch.cos(pi / 2 * t),
            d_alpha_t=pi / 2 * torch.cos(pi / 2 * t),
            d_sigma_t=-pi / 2 * torch.sin(pi / 2 * t),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        return 2.0 * torch.atan(snr) / torch.pi
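A quick sanity check of the schedulers defined above (illustrative only): variance-preserving schedules satisfy alpha_t^2 + sigma_t^2 = 1, and ``snr_inverse`` inverts the signal-to-noise ratio alpha_t / sigma_t.

    import torch
    from flow_matching.path.scheduler import CosineScheduler, LinearVPScheduler

    t = torch.linspace(0.01, 0.99, 5)
    for scheduler in (CosineScheduler(), LinearVPScheduler()):
        out = scheduler(t)
        assert torch.allclose(out.alpha_t**2 + out.sigma_t**2, torch.ones_like(t), atol=1e-5)
        assert torch.allclose(scheduler.snr_inverse(out.alpha_t / out.sigma_t), t, atol=1e-4)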
flow_matching/solver/__init__.py ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from .discrete_solver import MixtureDiscreteEulerSolver
from .ode_solver import ODESolver
from .riemannian_ode_solver import RiemannianODESolver
from .solver import Solver

__all__ = [
    "ODESolver",
    "Solver",
    "MixtureDiscreteEulerSolver",
    "RiemannianODESolver",
]
flow_matching/solver/discrete_solver.py ADDED
@@ -0,0 +1,428 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from contextlib import nullcontext
from math import ceil
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import functional as F

from flow_matching.path import MixtureDiscreteProbPath

from flow_matching.solver.solver import Solver
from flow_matching.utils import categorical, ModelWrapper
from .utils import get_nearest_times
from ..utils.multi_guidance import *

try:
    from tqdm import tqdm

    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False


class MixtureDiscreteEulerSolver(Solver):
    r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t`, the marginal probability path of ``path``.
    Given :math:`X_t \sim p_t`, the algorithm of the solver step from :math:`t` to :math:`t+h` for the i-th coordinate is:

    .. math::

        \begin{align*}
        & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
        & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
        & Z^i_{\text{change}} \sim U[0,1]\\
        & X_{t+h}^i \sim \begin{cases}
            \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
            \delta_{X_t^i}(\cdot) \text{ else }
        \end{cases}
        \end{align*}

    Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity of the mixture probability path is:

    .. math::

        u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],

    where

    .. math::

        \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],

    and

    .. math::

        \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].

    The source distribution :math:`p(x^i)` is given by ``p``.

    Args:
        model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`); output must be [..., vocabulary_size].
        path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
        vocabulary_size (int): size of the discrete vocabulary.
        source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when the divergence-free term for the probability velocity is non-zero. Defaults to None.
    """

    def __init__(
        self,
        model: ModelWrapper,
        path: MixtureDiscreteProbPath,
        vocabulary_size: int,
        source_distribution_p: Optional[Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.path = path
        self.vocabulary_size = vocabulary_size

        if source_distribution_p is not None:
            assert source_distribution_p.shape == torch.Size(
                [vocabulary_size]
            ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}."

        self.source_distribution_p = source_distribution_p

    @torch.no_grad()
    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        **model_extras,
    ) -> Tensor:
        """
        Sample a sequence of discrete values from the given model.

        .. code-block:: python

            import torch
            from flow_matching.path import MixtureDiscreteProbPath
            from flow_matching.path.scheduler import PolynomialConvexScheduler
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import MixtureDiscreteEulerSolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)

                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return ...

            model = DummyModel()
            path = MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0))
            solver = MixtureDiscreteEulerSolver(
                model=model, path=path, vocabulary_size=1000
            )

            x_init = torch.LongTensor([122, 725])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): The initial state.
            step_size (Optional[float]): If float then the time discretization is uniform with the given step size. If None then the time discretization is set to time_grid.
            div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0.
            dtype_categorical (torch.dtype): Precision to use for the categorical sampler. Defaults to torch.float32.
            time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then the time discretization is set by the time grid. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool): Whether to print progress bars. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The sampled sequence of discrete values.

        Raises:
            ImportError: To run in verbose mode, tqdm must be installed.
        """
        if not div_free == 0.0:
            assert (
                self.source_distribution_p is not None
            ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."

        # Initialize the current state `x_t` with the initial state `X_0`.
        time_grid = time_grid.to(device=x_init.device)

        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # If step_size is a float then the t discretization is uniform with step size step_size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            # get order of intermediate steps:
            order = torch.argsort(time_grid)
            # Compute intermediate steps to return via nearest points in t_discretization to time_grid.
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # Sample x_1 ~ p_1|t( . |x_t)
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                # Checks if final step
                if i == n_steps - 1:
                    x_t = x_1
                else:
                    # Compute u_t(x|x_t,x_1)
                    scheduler_output = self.path.scheduler(t=t)

                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t

                    delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to(
                        k_t.dtype
                    )  # [B, L, V]
                    u = d_k_t / (1 - k_t) * delta_1

                    # Add divergence-free part
                    div_free_t = div_free(t) if callable(div_free) else div_free

                    if div_free_t > 0:
                        p_0 = self.source_distribution_p[(None,) * x_t.dim()]
                        u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * (
                            (1 - k_t) * p_0 + k_t * delta_1
                        )

                    # Set u_t(x_t|x_t,x_1) = 0
                    delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size)  # [B, L, V]
                    u = torch.where(
                        delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
                    )

                    # Sample x_t ~ u_t( . |x_t,x_1)
                    intensity = u.sum(dim=-1)  # Assuming u_t(xt|xt,x1) := 0
                    mask_jump = torch.rand(
                        size=x_t.shape, device=x_t.device
                    ) < 1 - torch.exp(-h * intensity)

                    if mask_jump.sum() > 0:
                        x_t[mask_jump] = categorical(
                            u[mask_jump].to(dtype=dtype_categorical)
                        )

                steps_counter += 1
                t = t + h

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t

    @torch.no_grad()
    def multi_guidance_sample(
        self,
        args,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        score_models: list = None,
        num_objectives: int = 1,
        weights: list = None,
        **model_extras,
    ) -> Tensor:
        r"""Sample from the CTMC process while steering each Euler step with multiple
        objective (score) models.

        At every step, the candidate transitions proposed by the flow model are scored
        against the weighted objectives, filtered with adaptive hypercone filtering,
        and the surviving transition is applied with an Euler jump. The guidance
        helpers are provided by ``flow_matching.utils.multi_guidance``.
        """
        if not div_free == 0.0:
            raise NotImplementedError

        # Initialize the current state `x_t` with the initial state `X_0`.
        time_grid = time_grid.to(device=x_init.device)

        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # If step_size is a float then the t discretization is uniform with step size step_size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            # get order of intermediate steps:
            order = torch.argsort(time_grid)
            # Compute intermediate steps to return via nearest points in t_discretization to time_grid.
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        # Sample a random weight vector over the objectives unless one is given.
        if weights is not None:
            w = torch.tensor(weights).to(device=x_init.device)
        else:
            w, _ = select_random_weight_vector(num_objectives, args.num_div)
            w = w.to(device=x_init.device)
        print(f"Weight Vector: {w}")
        Phi = args.Phi_init
        ema_r_t = None

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                # Checks if final step
                if i != n_steps - 1:
                    # Compute u_t(y,x)
                    scheduler_output = self.path.scheduler(t=t)
                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t
                    u_t = d_k_t / (1 - k_t) * p_1t

                    # Score candidate transitions under the weighted objectives.
                    guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S = (
                        guided_transition_scoring(x_t, u_t, w, score_models, t, w, args)
                    )

                    # Keep only candidates inside the adaptive hypercone around w.
                    best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = (
                        adaptive_hypercone_filtering(
                            improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t
                        )
                    )
                    # Alternative selection strategies:
                    # best_candidate, accepted_mask, valid_mask, Phi, ema_r_t = hypercone_filtering(
                    #     improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=ema_r_t
                    # )
                    # best_candidate = get_best_candidate(improvement_values, cand_tokens, delta_S)

                    x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)

                steps_counter += 1
                t = t + h

                # Report the objective scores of the current state.
                scores = []
                for score_model in score_models:
                    sig = (
                        inspect.signature(score_model.forward)
                        if hasattr(score_model, "forward")
                        else inspect.signature(score_model)
                    )
                    if "t" in sig.parameters:
                        candidate_scores = score_model(x_t, 1)
                    else:
                        candidate_scores = score_model(x_t)

                    if isinstance(candidate_scores, tuple):
                        for score in candidate_scores:
                            scores.append(score.item())
                    else:
                        scores.append(candidate_scores.item())
                print(scores)

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                return torch.stack(res, dim=0)[order]
        else:
            return x_t
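A self-contained sketch of the unguided sampler (illustrative only; the toy model returns a uniform posterior, standing in for a trained x-prediction network, and assumes ``categorical`` samples indices from a probability tensor over the last dimension, as the solver itself does):

    import torch
    from flow_matching.path import MixtureDiscreteProbPath
    from flow_matching.path.scheduler import PolynomialConvexScheduler
    from flow_matching.solver import MixtureDiscreteEulerSolver
    from flow_matching.utils import ModelWrapper

    class UniformPosterior(ModelWrapper):
        """Toy x-prediction model: a uniform posterior over the vocabulary."""
        def __init__(self, vocab_size: int):
            super().__init__(None)
            self.vocab_size = vocab_size

        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
            return torch.full((*x.shape, self.vocab_size), 1.0 / self.vocab_size)

    vocab_size = 8
    solver = MixtureDiscreteEulerSolver(
        model=UniformPosterior(vocab_size),
        path=MixtureDiscreteProbPath(scheduler=PolynomialConvexScheduler(n=1.0)),
        vocabulary_size=vocab_size,
    )
    x_init = torch.randint(vocab_size, (2, 16))
    x_final = solver.sample(x_init=x_init, step_size=0.05)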
flow_matching/solver/ode_solver.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the CC-by-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional, Sequence, Tuple, Union
8
+
9
+ import torch
10
+ from torch import Tensor
11
+ from torchdiffeq import odeint
12
+
13
+ from flow_matching.solver.solver import Solver
14
+ from flow_matching.utils import gradient, ModelWrapper
15
+
16
+
17
+ class ODESolver(Solver):
18
+ """A class to solve ordinary differential equations (ODEs) using a specified velocity model.
19
+
20
+ This class utilizes a velocity field model to solve ODEs over a given time grid using numerical ode solvers.
21
+
22
+ Args:
23
+ velocity_model (Union[ModelWrapper, Callable]): a velocity field model receiving :math:`(x,t)` and returning :math:`u_t(x)`
24
+ """
25
+
26
+ def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
27
+ super().__init__()
28
+ self.velocity_model = velocity_model
29
+
30
+ def sample(
31
+ self,
32
+ x_init: Tensor,
33
+ step_size: Optional[float],
34
+ method: str = "euler",
35
+ atol: float = 1e-5,
36
+ rtol: float = 1e-5,
37
+ time_grid: Tensor = torch.tensor([0.0, 1.0]),
38
+ return_intermediates: bool = False,
39
+ enable_grad: bool = False,
40
+ **model_extras,
41
+ ) -> Union[Tensor, Sequence[Tensor]]:
42
+ r"""Solve the ODE with the velocity field.
43
+
44
+ Example:
45
+
46
+ .. code-block:: python
47
+
48
+ import torch
49
+ from flow_matching.utils import ModelWrapper
50
+ from flow_matching.solver import ODESolver
51
+
52
+ class DummyModel(ModelWrapper):
53
+ def __init__(self):
54
+ super().__init__(None)
55
+
56
+ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
57
+ return torch.ones_like(x) * 3.0 * t**2
58
+
59
+ velocity_model = DummyModel()
60
+ solver = ODESolver(velocity_model=velocity_model)
61
+ x_init = torch.tensor([0.0, 0.0])
62
+ step_size = 0.001
63
+ time_grid = torch.tensor([0.0, 1.0])
64
+
65
+ result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)
66
+
67
+ Args:
68
+             x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...].
+             step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
+             method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
+             atol (float): Absolute tolerance, used for adaptive step solvers.
+             rtol (float): Relative tolerance, used for adaptive step solvers.
+             time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then the time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
+             return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
+             enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
+             **model_extras: Additional input for the model.
+
+         Returns:
+             Union[Tensor, Sequence[Tensor]]: The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
+         """
+
+         time_grid = time_grid.to(x_init.device)
+
+         def ode_func(t, x):
+             return self.velocity_model(x=x, t=t, **model_extras)
+
+         ode_opts = {"step_size": step_size} if step_size is not None else {}
+
+         with torch.set_grad_enabled(enable_grad):
+             # Approximate ODE solution with numerical ODE solver
+             sol = odeint(
+                 ode_func,
+                 x_init,
+                 time_grid,
+                 method=method,
+                 options=ode_opts,
+                 atol=atol,
+                 rtol=rtol,
+             )
+
+         if return_intermediates:
+             return sol
+         else:
+             return sol[-1]
+
+     def compute_likelihood(
+         self,
+         x_1: Tensor,
+         log_p0: Callable[[Tensor], Tensor],
+         step_size: Optional[float],
+         method: str = "euler",
+         atol: float = 1e-5,
+         rtol: float = 1e-5,
+         time_grid: Tensor = torch.tensor([1.0, 0.0]),
+         return_intermediates: bool = False,
+         exact_divergence: bool = False,
+         enable_grad: bool = False,
+         **model_extras,
+     ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
+         r"""Solve for the log likelihood given a target sample at :math:`t=0`.
+
+         Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x.
+         The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.
+
+         Args:
+             x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`).
+             log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
+             step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
+             method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
+             atol (float): Absolute tolerance, used for adaptive step solvers.
+             rtol (float): Relative tolerance, used for adaptive step solvers.
+             time_grid (Tensor): If step_size is None then the time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
+             return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
+             exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
+             enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
+             **model_extras: Additional input for the model.
+
+         Returns:
+             Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of the given x_1.
+         """
+         assert (
+             time_grid[0] == 1.0 and time_grid[-1] == 0.0
+         ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"
+
+         # Fix the random projection for the Hutchinson divergence estimator
+         if not exact_divergence:
+             z = (torch.randn_like(x_1).to(x_1.device) < 0) * 2.0 - 1.0
+
+         def ode_func(x, t):
+             return self.velocity_model(x=x, t=t, **model_extras)
+
+         def dynamics_func(t, states):
+             xt = states[0]
+             with torch.set_grad_enabled(True):
+                 xt.requires_grad_()
+                 ut = ode_func(xt, t)
+
+                 if exact_divergence:
+                     # Compute exact divergence
+                     div = 0
+                     for i in range(ut.flatten(1).shape[1]):
+                         div += gradient(ut[:, i], xt, create_graph=True)[:, i]
+                 else:
+                     # Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
+                     ut_dot_z = torch.einsum(
+                         "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1)
+                     )
+                     grad_ut_dot_z = gradient(ut_dot_z, xt)
+                     div = torch.einsum(
+                         "ij,ij->i",
+                         grad_ut_dot_z.flatten(start_dim=1),
+                         z.flatten(start_dim=1),
+                     )
+
+             return ut.detach(), div.detach()
+
+         y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
+         ode_opts = {"step_size": step_size} if step_size is not None else {}
+
+         with torch.set_grad_enabled(enable_grad):
+             sol, log_det = odeint(
+                 dynamics_func,
+                 y_init,
+                 time_grid,
+                 method=method,
+                 options=ode_opts,
+                 atol=atol,
+                 rtol=rtol,
+             )
+
+         x_source = sol[-1]
+         source_log_p = log_p0(x_source)
+
+         if return_intermediates:
+             return sol, source_log_p + log_det[-1]
+         else:
+             return sol[-1], source_log_p + log_det[-1]
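
For reference, a minimal end-to-end usage sketch (not part of the commit). It assumes ODESolver is exported from flow_matching.solver and takes a velocity_model argument, as in the upstream flow_matching package; the constant-velocity model is purely illustrative.

import torch
from flow_matching.solver import ODESolver
from flow_matching.utils import ModelWrapper

class ConstantVelocity(torch.nn.Module):
    def forward(self, x, t, **extras):
        # u_t(x) = 1 everywhere: mass moves one unit as t goes from 0 to 1.
        return torch.ones_like(x)

solver = ODESolver(velocity_model=ModelWrapper(ConstantVelocity()))
x_0 = torch.zeros(4, 2)
x_1 = solver.sample(x_init=x_0, step_size=0.05, method="euler")
print(x_1)  # approximately torch.ones(4, 2)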
flow_matching/solver/riemannian_ode_solver.py ADDED
@@ -0,0 +1,261 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from typing import Callable
+
+ import torch
+ from torch import Tensor
+
+ from flow_matching.solver.solver import Solver
+ from flow_matching.utils import ModelWrapper
+ from flow_matching.utils.manifolds import geodesic, Manifold
+
+ try:
+     from tqdm import tqdm
+
+     TQDM_AVAILABLE = True
+ except ImportError:
+     TQDM_AVAILABLE = False
+
+
+ class RiemannianODESolver(Solver):
+     r"""Riemannian ODE solver.
+
+     Initialize the ``RiemannianODESolver``.
+
+     Args:
+         manifold (Manifold): the manifold to solve on.
+         velocity_model (ModelWrapper): a velocity field model receiving :math:`(x,t)`
+             and returning :math:`u_t(x)`, which is assumed to lie on the tangent plane at `x`.
+     """
+
+     def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
+         super().__init__()
+         self.manifold = manifold
+         self.velocity_model = velocity_model
+
+     def sample(
+         self,
+         x_init: Tensor,
+         step_size: float,
+         projx: bool = True,
+         proju: bool = True,
+         method: str = "euler",
+         time_grid: Tensor = torch.tensor([0.0, 1.0]),
+         return_intermediates: bool = False,
+         verbose: bool = False,
+         enable_grad: bool = False,
+         **model_extras,
+     ) -> Tensor:
+         r"""Solve the ODE with the `velocity_field` on the manifold.
+
+         Args:
+             x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`).
+             step_size (float): The step size. If None, the time discretization is given by time_grid.
+             projx (bool): Whether to project the point onto the manifold at each step. Defaults to True.
+             proju (bool): Whether to project the vector field onto the tangent plane at each step. Defaults to True.
+             method (str): One of ["euler", "midpoint", "rk4"]. Defaults to "euler".
+             time_grid (Tensor, optional): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then the time discretization is set by the time grid. Defaults to torch.tensor([0.0, 1.0]).
+             return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
+             verbose (bool, optional): Whether to print progress bars. Defaults to False.
+             enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
+             **model_extras: Additional input for the model.
+
+         Returns:
+             Tensor: The sampled sequence. Defaults to returning samples at :math:`t=1`.
+
+         Raises:
+             ImportError: To run in verbose mode, tqdm must be installed.
+         """
+         step_fns = {
+             "euler": _euler_step,
+             "midpoint": _midpoint_step,
+             "rk4": _rk4_step,
+         }
+         assert method in step_fns.keys(), f"Unknown method {method}"
+         step_fn = step_fns[method]
+
+         def velocity_func(x, t):
+             return self.velocity_model(x=x, t=t, **model_extras)
+
+         # --- Factor this out.
+         time_grid = torch.sort(time_grid.to(device=x_init.device)).values
+
+         if step_size is None:
+             # If step_size is None then set the t discretization to time_grid.
+             t_discretization = time_grid
+             n_steps = len(time_grid) - 1
+         else:
+             # If step_size is a float then the t discretization is uniform with step size set by step_size.
+             t_init = time_grid[0].item()
+             t_final = time_grid[-1].item()
+             assert (
+                 t_final - t_init
+             ) > step_size, f"Time interval [min(time_grid), max(time_grid)] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."
+
+             n_steps = math.ceil((t_final - t_init) / step_size)
+             t_discretization = torch.tensor(
+                 [step_size * i for i in range(n_steps)] + [t_final],
+                 device=x_init.device,
+             )
+         # ---
+         t0s = t_discretization[:-1]
+
+         if verbose:
+             if not TQDM_AVAILABLE:
+                 raise ImportError(
+                     "tqdm is required for verbose mode. Please install it."
+                 )
+             t0s = tqdm(t0s)
+
+         if return_intermediates:
+             xts = []
+             i_ret = 0
+
+         with torch.set_grad_enabled(enable_grad):
+             xt = x_init
+             for t0, t1 in zip(t0s, t_discretization[1:]):
+                 dt = t1 - t0
+                 xt_next = step_fn(
+                     velocity_func,
+                     xt,
+                     t0,
+                     dt,
+                     manifold=self.manifold,
+                     projx=projx,
+                     proju=proju,
+                 )
+                 if return_intermediates:
+                     while (
+                         i_ret < len(time_grid)
+                         and t0 <= time_grid[i_ret]
+                         and time_grid[i_ret] <= t1
+                     ):
+                         xts.append(
+                             interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
+                         )
+                         i_ret += 1
+                 xt = xt_next
+
+         if return_intermediates:
+             return torch.stack(xts, dim=0)
+         else:
+             return xt
+
+
+ def interp(manifold, xt, xt_next, t, t_next, t_ret):
+     return geodesic(manifold, xt, xt_next)(
+         (t_ret - t) / (t_next - t).reshape(1)
+     ).reshape_as(xt)
+
+
+ def _euler_step(
+     velocity_model: Callable,
+     xt: Tensor,
+     t0: Tensor,
+     dt: Tensor,
+     manifold: Manifold,
+     projx: bool = True,
+     proju: bool = True,
+ ) -> Tensor:
+     r"""Perform an Euler step on a manifold.
+
+     Args:
+         velocity_model (Callable): the velocity model
+         xt (Tensor): tensor containing the state at time t0
+         t0 (Tensor): the time at which this step is taken
+         dt (Tensor): the step size
+         manifold (Manifold): a manifold object
+         projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
+         proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.
+
+     Returns:
+         Tensor: tensor containing the state after the step
+     """
+     velocity_fn = lambda x, t: (
+         manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
+     )
+     projx_fn = lambda x: manifold.projx(x) if projx else x
+
+     vt = velocity_fn(xt, t0)
+     xt = xt + dt * vt
+
+     return projx_fn(xt)
+
+
+ def _midpoint_step(
+     velocity_model: Callable,
+     xt: Tensor,
+     t0: Tensor,
+     dt: Tensor,
+     manifold: Manifold,
+     projx: bool = True,
+     proju: bool = True,
+ ) -> Tensor:
+     r"""Perform a midpoint step on a manifold.
+
+     Args:
+         velocity_model (Callable): the velocity model
+         xt (Tensor): tensor containing the state at time t0
+         t0 (Tensor): the time at which this step is taken
+         dt (Tensor): the step size
+         manifold (Manifold): a manifold object
+         projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
+         proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.
+
+     Returns:
+         Tensor: tensor containing the state after the step
+     """
+     velocity_fn = lambda x, t: (
+         manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
+     )
+     projx_fn = lambda x: manifold.projx(x) if projx else x
+
+     half_dt = 0.5 * dt
+     vt = velocity_fn(xt, t0)
+     x_mid = xt + half_dt * vt
+     x_mid = projx_fn(x_mid)
+
+     xt = xt + dt * velocity_fn(x_mid, t0 + half_dt)
+
+     return projx_fn(xt)
+
+
+ def _rk4_step(
+     velocity_model: Callable,
+     xt: Tensor,
+     t0: Tensor,
+     dt: Tensor,
+     manifold: Manifold,
+     projx: bool = True,
+     proju: bool = True,
+ ) -> Tensor:
+     r"""Perform an RK4 step on a manifold.
+
+     Args:
+         velocity_model (Callable): the velocity model
+         xt (Tensor): tensor containing the state at time t0
+         t0 (Tensor): the time at which this step is taken
+         dt (Tensor): the step size
+         manifold (Manifold): a manifold object
+         projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
+         proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.
+
+     Returns:
+         Tensor: tensor containing the state after the step
+     """
+     velocity_fn = lambda x, t: (
+         manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
+     )
+     projx_fn = lambda x: manifold.projx(x) if projx else x
+
+     k1 = velocity_fn(xt, t0)
+     k2 = velocity_fn(projx_fn(xt + dt * k1 / 3), t0 + dt / 3)
+     k3 = velocity_fn(projx_fn(xt + dt * (k2 - k1 / 3)), t0 + dt * 2 / 3)
+     k4 = velocity_fn(projx_fn(xt + dt * (k1 - k2 + k3)), t0 + dt)
+
+     return projx_fn(xt + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125)
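
For reference, a minimal sketch (not part of the commit) of sampling on the unit sphere; the export path of RiemannianODESolver and the tangent-drift model are assumptions.

import torch
from flow_matching.solver import RiemannianODESolver
from flow_matching.utils import ModelWrapper
from flow_matching.utils.manifolds import Sphere

class TangentDrift(torch.nn.Module):
    def forward(self, x, t, **extras):
        v = torch.ones_like(x)
        # Remove the radial component so the velocity lies in the tangent plane.
        return v - (x * v).sum(dim=-1, keepdim=True) * x

manifold = Sphere()
solver = RiemannianODESolver(manifold, ModelWrapper(TangentDrift()))
x_0 = manifold.projx(torch.randn(8, 3))
x_1 = solver.sample(x_init=x_0, step_size=0.01, method="midpoint")
print(x_1.norm(dim=-1))  # stays ~1.0: projx is applied after every step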
flow_matching/solver/solver.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from abc import ABC, abstractmethod
+
+ from torch import nn, Tensor
+
+
+ class Solver(ABC, nn.Module):
+     """Abstract base class for solvers."""
+
+     @abstractmethod
+     def sample(self, x_0: Tensor = None) -> Tensor:
+         ...
flow_matching/solver/utils.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from torch import Tensor
+
+
+ def get_nearest_times(time_grid: Tensor, t_discretization: Tensor) -> Tensor:
+     """Return, for each time in ``time_grid``, the nearest time in ``t_discretization``."""
+     distances = torch.cdist(
+         time_grid.unsqueeze(1),
+         t_discretization.unsqueeze(1),
+         compute_mode="donot_use_mm_for_euclid_dist",
+     )
+     nearest_indices = distances.argmin(dim=1)
+
+     return t_discretization[nearest_indices]
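
For reference (not part of the commit), get_nearest_times snaps requested output times onto the solver's discretization:

import torch
from flow_matching.solver.utils import get_nearest_times

grid = torch.linspace(0.0, 1.0, steps=5)   # tensor([0.00, 0.25, 0.50, 0.75, 1.00])
wanted = torch.tensor([0.1, 0.6, 0.9])
print(get_nearest_times(wanted, grid))     # tensor([0.0000, 0.5000, 1.0000])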
flow_matching/utils/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .categorical_sampler import categorical
+ from .model_wrapper import ModelWrapper
+ from .utils import expand_tensor_like, gradient, unsqueeze_to_match
+
+ __all__ = [
+     "unsqueeze_to_match",
+     "expand_tensor_like",
+     "gradient",
+     "categorical",
+     "ModelWrapper",
+ ]
flow_matching/utils/categorical_sampler.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from torch import Tensor
+
+
+ def categorical(probs: Tensor) -> Tensor:
+     r"""Categorical sampler according to weights in the last dimension of ``probs`` using :func:`torch.multinomial`.
+
+     Args:
+         probs (Tensor): probabilities.
+
+     Returns:
+         Tensor: Samples.
+     """
+
+     return torch.multinomial(probs.flatten(0, -2), 1, replacement=True).view(
+         *probs.shape[:-1]
+     )
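
For reference (not part of the commit), categorical draws one index per position over the last dimension:

import torch
from flow_matching.utils import categorical

# (batch=2, length=2, vocab=2) probabilities; one token index is drawn per position.
probs = torch.tensor([[[0.9, 0.1], [0.1, 0.9]],
                      [[0.5, 0.5], [1.0, 0.0]]])
samples = categorical(probs)
print(samples.shape)  # torch.Size([2, 2])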
flow_matching/utils/manifolds/__init__.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .manifold import Euclidean, Manifold
+ from .sphere import Sphere
+ from .torus import FlatTorus
+ from .utils import geodesic
+
+ __all__ = [
+     "Euclidean",
+     "Manifold",
+     "Sphere",
+     "FlatTorus",
+     "geodesic",
+ ]
flow_matching/utils/manifolds/manifold.py ADDED
@@ -0,0 +1,93 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import abc
+
+ import torch.nn as nn
+ from torch import Tensor
+
+
+ class Manifold(nn.Module, metaclass=abc.ABCMeta):
+     """A manifold class that contains projection operations and logarithm and exponential maps."""
+
+     @abc.abstractmethod
+     def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+         r"""Computes the exponential map :math:`\exp_x(u)`.
+
+         Args:
+             x (Tensor): point on the manifold
+             u (Tensor): tangent vector at point :math:`x`
+
+         Raises:
+             NotImplementedError: if not implemented
+
+         Returns:
+             Tensor: transported point
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+         r"""Computes the logarithmic map :math:`\log_x(y)`.
+
+         Args:
+             x (Tensor): point on the manifold
+             y (Tensor): point on the manifold
+
+         Raises:
+             NotImplementedError: if not implemented
+
+         Returns:
+             Tensor: tangent vector at point :math:`x`
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def projx(self, x: Tensor) -> Tensor:
+         """Project point :math:`x` onto the manifold.
+
+         Args:
+             x (Tensor): point to be projected
+
+         Raises:
+             NotImplementedError: if not implemented
+
+         Returns:
+             Tensor: projected point on the manifold
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def proju(self, x: Tensor, u: Tensor) -> Tensor:
+         """Project vector :math:`u` onto the tangent space at :math:`x`.
+
+         Args:
+             x (Tensor): point on the manifold
+             u (Tensor): vector to be projected
+
+         Raises:
+             NotImplementedError: if not implemented
+
+         Returns:
+             Tensor: projected tangent vector
+         """
+         raise NotImplementedError
+
+
+ class Euclidean(Manifold):
+     """The Euclidean manifold."""
+
+     def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+         return x + u
+
+     def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+         return y - x
+
+     def projx(self, x: Tensor) -> Tensor:
+         return x
+
+     def proju(self, x: Tensor, u: Tensor) -> Tensor:
+         return u
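
For reference (not part of the commit), a quick check of the contract subclasses should satisfy: expmap and logmap are mutually inverse.

import torch
from flow_matching.utils.manifolds import Euclidean

m = Euclidean()
x, y = torch.randn(5, 3), torch.randn(5, 3)
assert torch.allclose(m.expmap(x, m.logmap(x, y)), y)  # exp_x(log_x(y)) == y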
flow_matching/utils/manifolds/sphere.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from torch import Tensor
+
+ from flow_matching.utils.manifolds import Manifold
+
+
+ class Sphere(Manifold):
+     """Represents a hypersphere in :math:`R^D`."""
+
+     EPS = {torch.float32: 1e-4, torch.float64: 1e-7}
+
+     def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+         norm_u = u.norm(dim=-1, keepdim=True)
+         exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u
+         retr = self.projx(x + u)
+         cond = norm_u > self.EPS[norm_u.dtype]
+
+         return torch.where(cond, exp, retr)
+
+     def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+         u = self.proju(x, y - x)
+         dist = self.dist(x, y, keepdim=True)
+         cond = dist.gt(self.EPS[x.dtype])
+         result = torch.where(
+             cond,
+             u * dist / u.norm(dim=-1, keepdim=True).clamp_min(self.EPS[x.dtype]),
+             u,
+         )
+         return result
+
+     def projx(self, x: Tensor) -> Tensor:
+         return x / x.norm(dim=-1, keepdim=True)
+
+     def proju(self, x: Tensor, u: Tensor) -> Tensor:
+         return u - (x * u).sum(dim=-1, keepdim=True) * x
+
+     def dist(self, x: Tensor, y: Tensor, *, keepdim=False) -> Tensor:
+         inner = (x * y).sum(-1, keepdim=keepdim)
+         return torch.acos(inner)
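
For reference (not part of the commit), a quick sanity check that the Sphere operations keep points on the manifold:

import torch
from flow_matching.utils.manifolds import Sphere

s = Sphere()
x = s.projx(torch.randn(4, 3))       # points with unit norm
u = s.proju(x, torch.randn(4, 3))    # tangent vectors: <x, u> = 0
y = s.expmap(x, u)
print(y.norm(dim=-1))                             # ~1.0 everywhere
print((s.dist(x, y) - u.norm(dim=-1)).abs())      # ~0 while ||u|| < pi (arc length)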
flow_matching/utils/manifolds/torus.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+
+ import torch
+ from torch import Tensor
+
+ from flow_matching.utils.manifolds import Manifold
+
+
+ class FlatTorus(Manifold):
+     r"""Represents a flat torus on the :math:`[0, 2\pi]^D` subspace. Isometric to the product of 1-D spheres."""
+
+     def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+         return (x + u) % (2 * math.pi)
+
+     def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+         return torch.atan2(torch.sin(y - x), torch.cos(y - x))
+
+     def projx(self, x: Tensor) -> Tensor:
+         return x % (2 * math.pi)
+
+     def proju(self, x: Tensor, u: Tensor) -> Tensor:
+         return u
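
For reference (not part of the commit): logmap on the flat torus returns the shortest signed angle, wrapping through 0 rather than taking the long way around:

import math
import torch
from flow_matching.utils.manifolds import FlatTorus

torus = FlatTorus()
x = torch.tensor([0.1])
y = torch.tensor([2 * math.pi - 0.1])
print(torus.logmap(x, y))                    # ~ -0.2
print(torus.expmap(x, torus.logmap(x, y)))   # ~ y (mod 2*pi)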
flow_matching/utils/manifolds/utils.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Callable
+
+ import torch
+ from torch import Tensor
+
+ from flow_matching.utils.manifolds import Manifold
+
+
+ def geodesic(
+     manifold: Manifold, start_point: Tensor, end_point: Tensor
+ ) -> Callable[[Tensor], Tensor]:
+     """Generate a parameterized function for the geodesic curve.
+
+     Args:
+         manifold (Manifold): the manifold to compute the geodesic on.
+         start_point (Tensor): point on the manifold at :math:`t=0`.
+         end_point (Tensor): point on the manifold at :math:`t=1`.
+
+     Returns:
+         Callable[[Tensor], Tensor]: a function that takes in :math:`t` and outputs the geodesic at time :math:`t`.
+     """
+
+     shooting_tangent_vec = manifold.logmap(start_point, end_point)
+
+     def path(t: Tensor) -> Tensor:
+         """Evaluate the geodesic at the given times.
+
+         Args:
+             t (Tensor): Times at which to compute points of the geodesic.
+
+         Returns:
+             Tensor: geodesic path evaluated at time t.
+         """
+         tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
+         points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
+
+         return points_at_time_t
+
+     return path
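
For reference (not part of the commit), evaluating a great-circle geodesic at several times:

import torch
from flow_matching.utils.manifolds import Sphere, geodesic

path = geodesic(Sphere(), torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.0, 1.0, 0.0]))
points = path(torch.tensor([0.0, 0.5, 1.0]))  # start, great-circle midpoint, end
print(points)  # the middle row is ~[0.7071, 0.7071, 0.0000]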
flow_matching/utils/model_wrapper.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from abc import ABC
+
+ from torch import nn, Tensor
+
+
+ class ModelWrapper(ABC, nn.Module):
+     """
+     This class is used to wrap around another model, adding custom forward pass logic.
+     """
+
+     def __init__(self, model: nn.Module):
+         super().__init__()
+         self.model = model
+
+     def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
+         r"""
+         This method defines how inputs should be passed through the wrapped model.
+         Here, we're assuming that the wrapped model takes both :math:`x` and :math:`t` as input,
+         along with any additional keyword arguments.
+
+         Optional things to do here:
+             - check that t is in the dimensions that the model is expecting.
+             - add custom forward pass logic.
+             - call the wrapped model.
+
+         | given x, t
+         | returns the model output for input x at time t, with extra information `extras`.
+
+         Args:
+             x (Tensor): input data to the model (batch_size, ...).
+             t (Tensor): time (batch_size).
+             **extras: additional information forwarded to the model, e.g., text condition.
+
+         Returns:
+             Tensor: model output.
+         """
+         return self.model(x=x, t=t, **extras)
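
For reference (not part of the commit), a minimal custom wrapper; the scalar-time broadcast below is illustrative:

import torch
from flow_matching.utils import ModelWrapper

class BroadcastTimeWrapper(ModelWrapper):
    """Expands the scalar time emitted by ODE solvers to a per-sample vector."""

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        if t.dim() == 0:
            t = t.expand(x.shape[0])  # broadcast the solver's scalar t to the batch
        return self.model(x=x, t=t, **extras)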
flow_matching/utils/multi_guidance.py ADDED
@@ -0,0 +1,216 @@
+ import torch
+ from flow_matching.utils import categorical
+ import math
+ import inspect
+
+
+ def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
+     def rec(n, H):
+         if n == 1:
+             return [[H]]
+         points = []
+         for i in range(H + 1):
+             for tail in rec(n - 1, H - i):
+                 points.append([i] + tail)
+         return points
+
+     points = rec(num_obj, num_div)
+     weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div
+     return weight_vectors
+
+
+ def select_random_weight_vector(num_obj: int, num_div: int):
+     weight_vectors = generate_simplex_lattice_points(num_obj, num_div)
+     idx = torch.randint(0, weight_vectors.size(0), (1,)).item()
+     random_weight_vector = weight_vectors[idx]
+     return random_weight_vector, weight_vectors
+
+
+ def z_score_norm(tensor, eps=1e-8):
+     mean = tensor.mean(dim=-1, keepdim=True)
+     std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
+     return (tensor - mean) / std
+
+
+ def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
+     B, L, vocab_size = u_t.shape
+     device = x_t.device
+     guided_u_t = u_t.clone()
+
+     # 1. Randomly select one position per sequence.
+     pos_indices = torch.randint(low=1, high=L - 2, size=(B,), device=device)  # shape: (B,)
+     batch_idx = torch.arange(B, device=device)
+     current_tokens = x_t[batch_idx, pos_indices]  # shape: (B,)
+
+     # 2. Build candidate tokens for each sequence and remove the self-transition.
+     full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)  # (B, vocab_size)
+     mask = (full_cand_tokens != current_tokens.unsqueeze(1))  # (B, vocab_size)
+     # Now, cand_tokens contains only candidate tokens that differ from the current token.
+     cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 1)  # (B, vocab_size-1)
+
+     # 3. Create candidate sequences by replacing the token at the selected position.
+     new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
+     new_x = new_x[mask].view(B, vocab_size - 1, L)  # (B, vocab_size-1, L)
+     new_x[batch_idx, :, pos_indices] = cand_tokens
+
+     new_x_flat = new_x.view(B * (vocab_size - 1), L)
+     improvements_list = []
+     with torch.no_grad():
+         count = 0
+         for s in s_models:
+             sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
+             if 't' in sig.parameters:
+                 candidate_scores = s(new_x_flat, t)
+                 base_score = s(x_t, t)
+             else:
+                 candidate_scores = s(new_x_flat)
+                 base_score = s(x_t)
+
+             if isinstance(candidate_scores, tuple):
+                 for k, score in enumerate(candidate_scores):
+                     improvement = score.view(B, vocab_size - 1) - base_score[k].unsqueeze(1)
+                     improvement = improvement.float()
+                     improvement *= importance[count]
+                     improvements_list.append(improvement.unsqueeze(2))
+                     count += 1
+             else:
+                 improvement = candidate_scores.view(B, vocab_size - 1) - base_score.unsqueeze(1)
+                 improvement = improvement.float()
+                 improvement *= importance[count]
+                 improvements_list.append(improvement.unsqueeze(2))  # (B, vocab_size-1, 1)
+                 count += 1
+
+     improvement_values = torch.cat(improvements_list, dim=2)  # (B, vocab_size-1, N)
+     if args.is_peptide:
+         improvement_values[:, :4, :] = -10  # Mask candidates that are special (non-residue) tokens
+
+     # 5. Compute ranking scores I_n.
+     ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1  # (B, vocab_size-1, N)
+     I_n = ranks / float(vocab_size - 1)
+     avg_I = I_n.mean(dim=2)
+     norm_avg_I = z_score_norm(avg_I)  # (B, vocab_size-1)
+
+     # 6. Compute the directional score D.
+     D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
+     norm_D = z_score_norm(D)  # (B, vocab_size-1)
+
+     # 7. Combine the scores.
+     delta_S = norm_avg_I + args.lambda_ * norm_D  # (B, vocab_size-1)
+
+     # 9. Update the guided velocities at the selected positions.
+     factor = torch.exp(args.beta * delta_S)  # (B, vocab_size-1)
+     factor = torch.clamp(factor, min=-100, max=100)
+
+     guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor
+
+     # 10. For the self-transition (current token) at the selected position,
+     # set its guided velocity to be the negative sum of the updated off-diagonals.
+     updated_vals = guided_u_t[batch_idx, pos_indices, :]  # (B, vocab_size)
+     sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
+     guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag
+
+     return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
+
+
+ def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
+     B, num_candidates, N = improvement_values.shape
+     device = improvement_values.device
+     eps = 1e-8
+
+     # Compute norms and angles.
+     imp_norm = torch.norm(improvement_values.float(), dim=2)  # (B, num_candidates)
+     dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
+     w_norm = torch.norm(w) + eps
+     cos_angle = dot_product / (imp_norm * w_norm + eps)
+     cos_angle = cos_angle.clamp(-1.0, 1.0)
+     angles = torch.acos(cos_angle)  # (B, num_candidates)
+
+     valid_mask = angles < math.pi / 2
+     accepted_mask = valid_mask & (angles <= Phi)  # (B, num_candidates)
+
+     # Determine the best candidate for each sequence.
+     # We'll use a loop over batch items (batch size is typically moderate).
+     best_candidate = torch.empty(B, dtype=torch.long, device=device)
+     for i in range(B):
+         # For sequence i, consider only valid candidates.
+         if valid_mask[i].any():
+             # There is at least one candidate with α^i < π/2.
+             if accepted_mask[i].any():
+                 # At least one candidate passes the hypercone: choose the one with max delta_S among those accepted.
+                 candidate_idx = torch.argmax(delta_S[i].masked_fill(~accepted_mask[i], float('-inf')))
+             else:
+                 # No candidate was accepted, but some are valid. Select the best candidate among the valid ones.
+                 candidate_idx = torch.argmax(delta_S[i].masked_fill(~valid_mask[i], float('-inf')))
+             best_candidate[i] = cand_tokens[i, candidate_idx]
+         else:
+             # No candidate is valid (all α^i >= π/2) → self-transition.
+             best_candidate[i] = -1
+
+     # Compute the rejection rate only over valid candidates.
+     rejection_rates = []
+     for i in range(B):
+         valid_candidates = valid_mask[i]
+         total_valid = valid_candidates.sum().item()
+         if total_valid > 0:
+             # Among valid candidates, count how many are rejected.
+             num_rejected = (valid_candidates.sum() - accepted_mask[i].sum()).item()
+             rejection_rates.append(num_rejected / total_valid)
+     if len(rejection_rates) > 0:
+         r_t = sum(rejection_rates) / len(rejection_rates)
+     else:
+         # If no sequence has any valid candidate, set r_t to 0.
+         r_t = 0.0
+
+     if ema_r_t is None:
+         ema_r_t = args.tau
+
+     # Update the hypercone angle and the EMA rejection rate only if there is at least one valid candidate in the batch.
+     if valid_mask.any():
+         new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
+         new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
+         new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()
+     else:
+         new_ema_r_t = ema_r_t
+         new_Phi = Phi  # No update if no valid candidate exists.
+
+     return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
+
+
+ def get_best_candidate(improvement_values, cand_tokens, delta_S):
+     B, num_candidates, N = improvement_values.shape
+     device = improvement_values.device
+     best_candidate = torch.empty(B, dtype=torch.long, device=device)
+
+     for i in range(B):
+         candidate_idx = torch.argmax(delta_S[i])
+         best_candidate[i] = cand_tokens[i, candidate_idx]
+
+     return best_candidate
+
+
+ def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
+     B, L, V = guided_u_t.shape
+     device = x_t.device
+     u = torch.zeros_like(guided_u_t)
+
+     valid_mask = best_candidate != -1
+     if valid_mask.any():
+         valid_idx = torch.nonzero(valid_mask).squeeze(-1)
+         # For these sequences, update the velocity at the selected position and candidate token.
+         u[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] = \
+             guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]]
+
+     # Compute the intensity at the selected positions.
+     # For sequences with no valid candidate (i.e., self-transition), the intensity remains zero.
+     intensity = torch.zeros(B, device=device)
+     if valid_mask.any():
+         intensity[valid_idx] = u[valid_idx, pos_indices[valid_idx]].sum(dim=-1)
+
+     # According to the Euler sampling formula, `p_jump` should be `1 - torch.exp(-h * intensity)`.
+     # However, since `h = 1 / T` is small, p_jump becomes tiny and slows down sampling.
+     # To compensate, we scale `intensity` by T; this is equivalent to setting `args.beta` to `T * args.beta`.
+     # So for faster sampling, we just use `1 - torch.exp(-1 * intensity)`.
+     p_jump = 1 - torch.exp(-1 * intensity)
+
+     rand_val = torch.rand(B, device=device)
+
+     jump_decision = (rand_val < p_jump) & valid_mask
+     if jump_decision.any():
+         print("Jump!")
+     # For sequences where a jump is decided, update the token at pos_indices to best_candidate.
+     x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]
+
+     return x_t
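
For reference (not part of the commit): with 3 objectives and 4 divisions the lattice has C(4+3-1, 3-1) = 15 weight vectors, each summing to 1:

from flow_matching.utils.multi_guidance import generate_simplex_lattice_points

w = generate_simplex_lattice_points(num_obj=3, num_div=4)
print(w.shape)       # torch.Size([15, 3])
print(w.sum(dim=1))  # every row sums to 1.0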
flow_matching/utils/multi_guidance_cnp.py ADDED
@@ -0,0 +1,217 @@
+ import torch
+ from flow_matching.utils import categorical
+ import math
+ import inspect
+ import random
+
+
+ def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
+     def rec(n, H):
+         if n == 1:
+             return [[H]]
+         points = []
+         for i in range(H + 1):
+             for tail in rec(n - 1, H - i):
+                 points.append([i] + tail)
+         return points
+
+     points = rec(num_obj, num_div)
+     weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div
+     return weight_vectors
+
+
+ def select_random_weight_vector(num_obj: int, num_div: int):
+     weight_vectors = generate_simplex_lattice_points(num_obj, num_div)
+     idx = torch.randint(0, weight_vectors.size(0), (1,)).item()
+     random_weight_vector = weight_vectors[idx]
+     return random_weight_vector, weight_vectors
+
+
+ def z_score_norm(tensor, eps=1e-8):
+     mean = tensor.mean(dim=-1, keepdim=True)
+     std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
+     return (tensor - mean) / std
+
+
+ def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
+     B, L, vocab_size = u_t.shape
+     device = x_t.device
+     guided_u_t = u_t.clone()
+
+     # 1. Select one position shared by the whole batch (position 6 is held fixed).
+     # pos_indices = torch.randint(low=1, high=L - 2, size=(B,), device=device)
+     pos_indices = torch.tensor([random.choice([i for i in range(1, L - 2) if i != 6])]).to(x_t.device)
+     batch_idx = torch.arange(B, device=device)
+     current_tokens = x_t[batch_idx, pos_indices]  # shape: (B,)
+
+     # 2. Build candidate tokens for each sequence, removing the self-transition and token 23.
+     full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)  # (B, vocab_size)
+     mask = (full_cand_tokens != current_tokens.unsqueeze(1)) & (full_cand_tokens != 23)  # (B, vocab_size)
+     # Now, cand_tokens contains only candidate tokens that differ from the current token.
+     cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 2)  # (B, vocab_size-2)
+
+     # 3. Create candidate sequences by replacing the token at the selected position.
+     new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
+     new_x = new_x[mask].view(B, vocab_size - 2, L)  # (B, vocab_size-2, L)
+     new_x[batch_idx, :, pos_indices] = cand_tokens
+
+     new_x_flat = new_x.view(B * (vocab_size - 2), L)
+     improvements_list = []
+     with torch.no_grad():
+         count = 0
+         for s in s_models:
+             sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
+             if 't' in sig.parameters:
+                 candidate_scores = s(new_x_flat, t)
+                 base_score = s(x_t, t)
+             else:
+                 candidate_scores = s(new_x_flat)
+                 base_score = s(x_t)
+
+             if isinstance(candidate_scores, tuple):
+                 for k, score in enumerate(candidate_scores):
+                     improvement = score.view(B, vocab_size - 2) - base_score[k].unsqueeze(1)
+                     improvement = improvement.float()
+                     improvement *= importance[count]
+                     improvements_list.append(improvement.unsqueeze(2))
+                     count += 1
+             else:
+                 improvement = candidate_scores.view(B, vocab_size - 2) - base_score.unsqueeze(1)
+                 improvement = improvement.float()
+                 improvement *= importance[count]
+                 improvements_list.append(improvement.unsqueeze(2))  # (B, vocab_size-2, 1)
+                 count += 1
+
+     improvement_values = torch.cat(improvements_list, dim=2)  # (B, vocab_size-2, N)
+     if args.is_peptide:
+         improvement_values[:, :4, :] = -10  # Mask candidates that are special (non-residue) tokens
+
+     # 5. Compute ranking scores I_n.
+     ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1  # (B, vocab_size-2, N)
+     I_n = ranks / float(vocab_size - 2)
+     avg_I = I_n.mean(dim=2)
+     norm_avg_I = z_score_norm(avg_I)  # (B, vocab_size-2)
+
+     # 6. Compute the directional score D.
+     D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
+     norm_D = z_score_norm(D)  # (B, vocab_size-2)
+
+     # 7. Combine the scores.
+     delta_S = norm_avg_I + args.lambda_ * norm_D  # (B, vocab_size-2)
+
+     # 9. Update the guided velocities at the selected positions.
+     factor = torch.exp(args.beta * delta_S)  # (B, vocab_size-2)
+     factor = torch.clamp(factor, min=-100, max=100)
+
+     guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor
+
+     # 10. For the self-transition (current token) at the selected position,
+     # set its guided velocity to be the negative sum of the updated off-diagonals.
+     updated_vals = guided_u_t[batch_idx, pos_indices, :]  # (B, vocab_size)
+     sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
+     guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag
+
+     return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
+
+
+ def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
+     B, num_candidates, N = improvement_values.shape
+     device = improvement_values.device
+     eps = 1e-8
+
+     # Compute norms and angles.
+     imp_norm = torch.norm(improvement_values.float(), dim=2)  # (B, num_candidates)
+     dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
+     w_norm = torch.norm(w) + eps
+     cos_angle = dot_product / (imp_norm * w_norm + eps)
+     cos_angle = cos_angle.clamp(-1.0, 1.0)
+     angles = torch.acos(cos_angle)  # (B, num_candidates)
+
+     valid_mask = angles < math.pi / 2
+     accepted_mask = valid_mask & (angles <= Phi)  # (B, num_candidates)
+
+     # Determine the best candidate for each sequence.
+     # We'll use a loop over batch items (batch size is typically moderate).
+     best_candidate = torch.empty(B, dtype=torch.long, device=device)
+     for i in range(B):
+         # For sequence i, consider only valid candidates.
+         if valid_mask[i].any():
+             # There is at least one candidate with α^i < π/2.
+             if accepted_mask[i].any():
+                 # At least one candidate passes the hypercone: choose the one with max delta_S among those accepted.
+                 candidate_idx = torch.argmax(delta_S[i].masked_fill(~accepted_mask[i], float('-inf')))
+             else:
+                 # No candidate was accepted, but some are valid. Select the best candidate among the valid ones.
+                 candidate_idx = torch.argmax(delta_S[i].masked_fill(~valid_mask[i], float('-inf')))
+             best_candidate[i] = cand_tokens[i, candidate_idx]
+         else:
+             # No candidate is valid (all α^i >= π/2) → self-transition.
+             best_candidate[i] = -1
+
+     # Compute the rejection rate only over valid candidates.
+     rejection_rates = []
+     for i in range(B):
+         valid_candidates = valid_mask[i]
+         total_valid = valid_candidates.sum().item()
+         if total_valid > 0:
+             # Among valid candidates, count how many are rejected.
+             num_rejected = (valid_candidates.sum() - accepted_mask[i].sum()).item()
+             rejection_rates.append(num_rejected / total_valid)
+     if len(rejection_rates) > 0:
+         r_t = sum(rejection_rates) / len(rejection_rates)
+     else:
+         # If no sequence has any valid candidate, set r_t to 0.
+         r_t = 0.0
+
+     if ema_r_t is None:
+         ema_r_t = args.tau
+
+     # Update the hypercone angle and the EMA rejection rate only if there is at least one valid candidate in the batch.
+     if valid_mask.any():
+         new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
+         new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
+         new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()
+     else:
+         new_ema_r_t = ema_r_t
+         new_Phi = Phi  # No update if no valid candidate exists.
+
+     return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
+
+
+ def get_best_candidate(improvement_values, cand_tokens, delta_S):
+     B, num_candidates, N = improvement_values.shape
+     device = improvement_values.device
+     best_candidate = torch.empty(B, dtype=torch.long, device=device)
+
+     for i in range(B):
+         candidate_idx = torch.argmax(delta_S[i])
+         best_candidate[i] = cand_tokens[i, candidate_idx]
+
+     return best_candidate
+
+
+ def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
+     B, L, V = guided_u_t.shape
+     device = x_t.device
+     u = torch.zeros_like(guided_u_t)
+
+     valid_mask = best_candidate != -1
+     if valid_mask.any():
+         valid_idx = torch.nonzero(valid_mask).squeeze(-1)
+         # For these sequences, update the velocity at the selected position and candidate token.
+         u[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] = \
+             guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]]
+
+     # Compute the intensity at the selected positions.
+     # For sequences with no valid candidate (i.e., self-transition), the intensity remains zero.
+     intensity = torch.zeros(B, device=device)
+     if valid_mask.any():
+         intensity[valid_idx] = u[valid_idx, pos_indices[valid_idx]].sum(dim=-1)
+
+     # According to the Euler sampling formula, `p_jump` should be `1 - torch.exp(-h * intensity)`.
+     # However, since `h = 1 / T` is small, p_jump becomes tiny and slows down sampling.
+     # To compensate, we scale `intensity` by T; this is equivalent to setting `args.beta` to `T * args.beta`.
+     # So for faster sampling, we just use `1 - torch.exp(-1 * intensity)`.
+     p_jump = 1 - torch.exp(-1 * intensity)
+
+     rand_val = torch.rand(B, device=device)
+
+     jump_decision = (rand_val < p_jump) & valid_mask
+
+     # For sequences where a jump is decided, update the token at pos_indices to best_candidate.
+     x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]
+
+     return x_t
flow_matching/utils/utils.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the CC-by-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Optional
+
+ import torch
+ from torch import Tensor
+
+
+ def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
+     """
+     Unsqueeze the source tensor to match the dimensionality of the target tensor.
+
+     Args:
+         source (Tensor): The source tensor to be unsqueezed.
+         target (Tensor): The target tensor to match the dimensionality of.
+         how (str, optional): Whether to unsqueeze the source tensor at the beginning
+             ("prefix") or end ("suffix"). Defaults to "suffix".
+
+     Returns:
+         Tensor: The unsqueezed source tensor.
+     """
+     assert (
+         how == "prefix" or how == "suffix"
+     ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."
+
+     dim_diff = target.dim() - source.dim()
+
+     for _ in range(dim_diff):
+         if how == "prefix":
+             source = source.unsqueeze(0)
+         elif how == "suffix":
+             source = source.unsqueeze(-1)
+
+     return source
+
+
+ def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
+     """`input_tensor` is a 1d vector of length equal to the batch size of `expand_to`;
+     expand `input_tensor` to have the same shape as `expand_to` along all remaining dimensions.
+
+     Args:
+         input_tensor (Tensor): (batch_size,).
+         expand_to (Tensor): (batch_size, ...).
+
+     Returns:
+         Tensor: (batch_size, ...).
+     """
+     assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
+     assert (
+         input_tensor.shape[0] == expand_to.shape[0]
+     ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."
+
+     dim_diff = expand_to.ndim - input_tensor.ndim
+
+     t_expanded = input_tensor.clone()
+     t_expanded = t_expanded.reshape(-1, *([1] * dim_diff))
+
+     return t_expanded.expand_as(expand_to)
+
+
+ def gradient(
+     output: Tensor,
+     x: Tensor,
+     grad_outputs: Optional[Tensor] = None,
+     create_graph: bool = False,
+ ) -> Tensor:
+     """
+     Compute the gradient of the inner product of output and grad_outputs w.r.t. :math:`x`.
+
+     Args:
+         output (Tensor): [N, D] Output of the function.
+         x (Tensor): [N, d_1, d_2, ...] input.
+         grad_outputs (Optional[Tensor]): [N, D] Gradient of outputs; if `None`,
+             a tensor of ones is used.
+         create_graph (bool): If True, the graph of the derivative will be constructed,
+             allowing higher-order derivative products to be computed. Defaults to False.
+
+     Returns:
+         Tensor: [N, d_1, d_2, ...]: the gradient w.r.t. x.
+     """
+
+     if grad_outputs is None:
+         grad_outputs = torch.ones_like(output).detach()
+     grad = torch.autograd.grad(
+         output, x, grad_outputs=grad_outputs, create_graph=create_graph
+     )[0]
+     return grad
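
For reference (not part of the commit), a quick check that gradient recovers the analytic derivative:

import torch
from flow_matching.utils import gradient

x = torch.randn(4, 3, requires_grad=True)
out = (x ** 2).sum(dim=1)                       # f(x) = ||x||^2, one value per sample
print(torch.allclose(gradient(out, x), 2 * x))  # True: d/dx ||x||^2 = 2x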
models/classifier.py ADDED
@@ -0,0 +1,116 @@
+ from torch import nn
+ import torch.nn.functional as F
+ import torch
+ import numpy as np
+ import copy
+
+
+ class GaussianFourierProjection(nn.Module):
+     """
+     Gaussian random features for encoding time steps.
+     """
+
+     def __init__(self, embed_dim, scale=30.):
+         super().__init__()
+         # Randomly sample weights during initialization. These weights are fixed
+         # during optimization and are not trainable.
+         self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
+
+     def forward(self, x):
+         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+
+ class Dense(nn.Module):
+     """
+     A fully connected layer.
+     """
+
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.dense = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         return self.dense(x)
+
+
+ class Swish(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         return torch.sigmoid(x) * x
+
+
+ class CNNClassifier(nn.Module):
+     def __init__(self, args, alphabet_size, num_cls, classifier=False):
+         super().__init__()
+         self.alphabet_size = alphabet_size
+         self.args = args
+         self.classifier = classifier
+         self.num_cls = num_cls
+
+         if self.args.clean_data:
+             self.linear = nn.Embedding(self.alphabet_size, embedding_dim=args.hidden_dim)
+         else:
+             expanded_simplex_input = args.cls_expanded_simplex or (not classifier and args.mode in ('dirichlet', 'riemannian'))
+             inp_size = self.alphabet_size * (2 if expanded_simplex_input else 1)
+             if args.mode in ('ardm', 'lrar') and not classifier:
+                 inp_size += 1  # plus one for the mask token of these models
+             self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
+             self.time_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=args.hidden_dim), nn.Linear(args.hidden_dim, args.hidden_dim))
+
+         self.num_layers = 5 * args.num_cnn_stacks
+         self.convs = [nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)]
+         self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
+         self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
+         self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
+         self.final_conv = nn.Sequential(nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
+                                         nn.ReLU(),
+                                         nn.Conv1d(args.hidden_dim, args.hidden_dim if classifier else self.alphabet_size, kernel_size=1))
+         self.dropout = nn.Dropout(args.dropout)
+         if classifier:
+             self.cls_head = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim),
+                                           nn.ReLU(),
+                                           nn.Linear(args.hidden_dim, self.num_cls))
+
+         if self.args.cls_free_guidance and not self.classifier:
+             self.cls_embedder = nn.Embedding(num_embeddings=self.num_cls + 1, embedding_dim=args.hidden_dim)
+             self.cls_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
+
+     def forward(self, seq, t, cls=None, return_embedding=False):
+         if self.args.clean_data:
+             feat = self.linear(seq)
+             feat = feat.permute(0, 2, 1)
+         else:
+             time_emb = F.relu(self.time_embedder(t))
+             feat = seq.permute(0, 2, 1)
+             feat = F.relu(self.linear(feat))
+
+         if self.args.cls_free_guidance and not self.classifier and cls is not None:
+             cls_emb = self.cls_embedder(cls)
+
+         for i in range(self.num_layers):
+             h = self.dropout(feat.clone())
+             if not self.args.clean_data:
+                 h = h + self.time_layers[i](time_emb)[:, :, None]
+             if self.args.cls_free_guidance and not self.classifier and cls is not None:
+                 h = h + self.cls_layers[i](cls_emb)[:, :, None]
+             h = self.norms[i](h.permute(0, 2, 1))
+             h = F.relu(self.convs[i](h.permute(0, 2, 1)))
+             if h.shape == feat.shape:
+                 feat = h + feat
+             else:
+                 feat = h
+         feat = self.final_conv(feat)
+         feat = feat.permute(0, 2, 1)
+         if self.classifier:
+             feat = feat.mean(dim=1)
+             if return_embedding:
+                 embedding = self.cls_head[:1](feat)
+                 return self.cls_head[1:](embedding), embedding
+             else:
+                 return self.cls_head(feat)
+         return feat
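
For reference (not part of the commit), a smoke-test sketch; the namespace collects only the args fields CNNClassifier reads, and all values are illustrative:

import torch
from types import SimpleNamespace
from models.classifier import CNNClassifier

args = SimpleNamespace(
    clean_data=True,        # embed tokens directly; the time embedding is skipped
    hidden_dim=64,
    num_cnn_stacks=1,
    dropout=0.0,
    cls_expanded_simplex=False,
    mode="dirichlet",
    cls_free_guidance=False,
)
clf = CNNClassifier(args, alphabet_size=24, num_cls=2, classifier=True)
logits = clf(torch.randint(0, 24, (2, 50)), t=None)  # t is unused when clean_data=True
print(logits.shape)  # torch.Size([2, 2])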
models/enhancer_models.py ADDED
@@ -0,0 +1,215 @@
1
+ from torch import nn
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import copy
6
+ import pdb
7
+
8
+ class GaussianFourierProjection(nn.Module):
9
+ """
10
+ Gaussian random features for encoding time steps.
11
+ """
12
+
13
+ def __init__(self, embed_dim, scale=30.):
14
+ super().__init__()
15
+ # Randomly sample weights during initialization. These weights are fixed
16
+ # during optimization and are not trainable.
17
+ self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
18
+
19
+ def forward(self, x):
20
+ x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
21
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
22
+
23
+ class Dense(nn.Module):
24
+ """
25
+ A fully connected layer that reshapes outputs to feature maps.
26
+ """
27
+
28
+ def __init__(self, input_dim, output_dim):
29
+ super().__init__()
30
+ self.dense = nn.Linear(input_dim, output_dim)
31
+
32
+ def forward(self, x):
33
+ return self.dense(x)[...]
34
+
35
+ class Swish(nn.Module):
36
+ def __init__(self):
37
+ super().__init__()
38
+
39
+ def forward(self, x):
40
+ return torch.sigmoid(x) * x
41
+
42
+ class CNNModel(nn.Module):
+     """A time-dependent score-based model built on a stack of dilated 1D convolutions."""
+ 
+     def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
+         """
+         Args:
+             embed_dim (int): Dimensionality of the token and time embeddings.
+         """
+         super().__init__()
+         self.alphabet_size = alphabet_size
+ 
+         self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)
+ 
+         self.time_embed = nn.Sequential(
+             GaussianFourierProjection(embed_dim=embed_dim),
+             nn.Linear(embed_dim, embed_dim)
+         )
+ 
+         self.swish = Swish()
+ 
+         n = hidden_dim
+ 
+         self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)
+ 
+         self.blocks = nn.ModuleList([
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256)
+         ])
+ 
+         self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
+         self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])
+ 
+         self.final = nn.Sequential(
+             nn.Conv1d(n, n, kernel_size=1),
+             nn.GELU(),
+             nn.Conv1d(n, self.alphabet_size, kernel_size=1)
+         )
+ 
+     def forward(self, x, t):
+         """
+         Args:
+             x: Tensor of shape (B, L) containing DNA token indices.
+             t: Tensor of shape (B,) containing the time steps.
+         Returns:
+             out: Tensor of shape (B, L, 4) with output logits for each DNA base.
+         """
+         x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)
+ 
+         time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)
+ 
+         out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
+         out = self.swish(self.linear(out))  # (B, n, L)
+ 
+         # Process through convolutional blocks, adding time conditioning via dense layers.
+         for block, dense, norm in zip(self.blocks, self.denses, self.norms):
+             # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
+             h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
+             # Residual connection if shapes match.
+             if h.shape == out.shape:
+                 out = h + out
+             else:
+                 out = h
+ 
+         out = self.final(out)  # (B, 4, L)
+         out = out.permute(0, 2, 1)  # (B, L, 4)
+ 
+         # Zero-center the logits along the alphabet dimension.
+         out = out - out.mean(dim=-1, keepdim=True)
+         return out
+ 
+ 
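A quick forward-pass shape check for the model above (illustrative sizes only):

    model = CNNModel(alphabet_size=4, embed_dim=256, hidden_dim=256)
    x = torch.randint(0, 4, (2, 500))   # 2 DNA sequences of length 500
    t = torch.rand(2)
    logits = model(x, t)                # (2, 500, 4), zero-mean over the last dim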
+ class MLPModel(nn.Module):
+     def __init__(
+             self, input_dim: int = 128, time_dim: int = 1, hidden_dim=128, length=500):
+         super().__init__()
+         self.input_dim = input_dim
+         self.time_dim = time_dim
+         self.hidden_dim = hidden_dim
+ 
+         self.time_embedding = nn.Linear(1, time_dim)
+         self.token_embedding = torch.nn.Embedding(self.input_dim, hidden_dim)
+ 
+         self.swish = Swish()
+ 
+         # The flattened input ties the network to sequences of exactly `length` tokens.
+         self.main = nn.Sequential(
+             self.swish,
+             nn.Linear(hidden_dim * length + time_dim, hidden_dim),
+             self.swish,
+             nn.Linear(hidden_dim, hidden_dim),
+             self.swish,
+             nn.Linear(hidden_dim, hidden_dim),
+             self.swish,
+             nn.Linear(hidden_dim, self.input_dim * length),
+         )
+ 
+     def forward(self, x, t):
+         '''
+         x: shape (B, L)
+         t: shape (B,)
+         '''
+         t = self.time_embedding(t.unsqueeze(-1))
+         x = self.token_embedding(x)
+ 
+         B, N, d = x.shape
+         x = x.reshape(B, N * d)
+ 
+         h = torch.cat([x, t], dim=1)
+         h = self.main(h)
+ 
+         h = h.reshape(B, N, self.input_dim)
+ 
+         return h
+ 
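Unlike the convolutional model, this MLP flattens the whole sequence, so it only accepts inputs of the configured length; a brief sketch:

    mlp = MLPModel(input_dim=4, hidden_dim=128, length=500)
    x = torch.randint(0, 4, (2, 500))
    h = mlp(x, torch.rand(2))   # (2, 500, 4); any other length raises a shape error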
+ class DirichletCNNModel(nn.Module):
+     def __init__(self, args, alphabet_size):
+         super().__init__()
+         self.alphabet_size = alphabet_size
+         self.args = args
+         expanded_simplex_input = args.cls_expanded_simplex and (args.mode == 'dirichlet' or args.mode == 'riemannian')
+         inp_size = self.alphabet_size * (2 if expanded_simplex_input else 1)
+         self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
+         self.time_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=args.hidden_dim),
+                                            nn.Linear(args.hidden_dim, args.hidden_dim))
+ 
+         self.num_layers = 5 * args.num_cnn_stacks
+         self.convs = [nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
+                       nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)]
+         self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
+         self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
+         self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
+         self.final_conv = nn.Sequential(nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
+                                         nn.ReLU(),
+                                         nn.Conv1d(args.hidden_dim, self.alphabet_size, kernel_size=1))
+         self.dropout = nn.Dropout(args.dropout)
+ 
+     def forward(self, seq, t):
+         time_emb = F.relu(self.time_embedder(t))
+         feat = seq.permute(0, 2, 1)
+         feat = F.relu(self.linear(feat))
+ 
+         for i in range(self.num_layers):
+             h = self.dropout(feat.clone())
+             if not self.args.clean_data:
+                 h = h + self.time_layers[i](time_emb)[:, :, None]
+             # LayerNorm normalizes the channel dim, so permute to (B, L, C) and back.
+             h = self.norms[i](h.permute(0, 2, 1))
+             h = F.relu(self.convs[i](h.permute(0, 2, 1)))
+             if h.shape == feat.shape:
+                 feat = h + feat
+             else:
+                 feat = h
+         feat = self.final_conv(feat)
+         feat = feat.permute(0, 2, 1)
+         return feat
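The constructor reads its hyperparameters from an `args` namespace; a minimal instantiation sketch with assumed field values:

    from types import SimpleNamespace

    args = SimpleNamespace(mode='dirichlet', cls_expanded_simplex=True, hidden_dim=128,
                           num_cnn_stacks=1, dropout=0.1, clean_data=False)
    model = DirichletCNNModel(args, alphabet_size=4)
    seq = torch.rand(2, 500, 8)      # expanded simplex input: 2 * alphabet_size channels
    out = model(seq, torch.rand(2))  # (2, 500, 4)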
models/peptide_classifiers.py ADDED
@@ -0,0 +1,751 @@
+ import pdb
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import numpy as np
+ import pytorch_lightning as pl
+ import time
+ from transformers import AutoModel, AutoConfig, AutoTokenizer, EsmModel
+ import xgboost as xgb
+ import esm
+ 
+ from flow_matching.path import MixtureDiscreteProbPath
+ from flow_matching.path.scheduler import PolynomialConvexScheduler
+ from flow_matching.solver import MixtureDiscreteEulerSolver
+ from flow_matching.utils import ModelWrapper
+ from flow_matching.loss import MixturePathGeneralizedKL
+ 
+ from models.peptide_models import CNNModel
+ from modules.bindevaluator_modules import *
+ 
+ def parse_motifs(motif: str) -> torch.Tensor:
+     parts = motif.split(',')
+     result = []
+ 
+     for part in parts:
+         part = part.strip()
+         if '-' in part:
+             start, end = map(int, part.split('-'))
+             result.extend(range(start, end + 1))
+         else:
+             result.append(int(part))
+ 
+     # Convert from 1-based residue numbering to 0-based indices.
+     result = [pos - 1 for pos in result]
+     print(f'Target Motifs: {result}')
+     return torch.tensor(result)
+ 
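For example, a motif string mixing a range and a single residue (1-based) maps to 0-based indices:

    parse_motifs('10-12, 15')   # prints [9, 10, 11, 14], returns tensor([ 9, 10, 11, 14])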
+ class BindEvaluator(pl.LightningModule):
+     def __init__(self, n_layers, d_model, d_hidden, n_head,
+                  d_k, d_v, d_inner, dropout=0.2,
+                  learning_rate=0.00001, max_epochs=15, kl_weight=1):
+         super(BindEvaluator, self).__init__()
+ 
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.esm_model.eval()
+         # Freeze all ESM parameters; only the downstream layers are trained.
+         for param in self.esm_model.parameters():
+             param.requires_grad = False
+ 
+         self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
+                                                n_head, d_k, d_v, d_inner, dropout=dropout)
+ 
+         self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
+                                                                 d_k, d_v, dropout=dropout)
+ 
+         self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
+ 
+         self.output_projection_prot = nn.Linear(d_model, 1)
+ 
+         self.learning_rate = learning_rate
+         self.max_epochs = max_epochs
+         self.kl_weight = kl_weight
+ 
+         self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # initial threshold
+         self.historical_memory = 0.9
+         self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site / non-binding-site weights
+ 
+     def forward(self, binder_tokens, target_tokens):
+         peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
+         protein_sequence = self.esm_model(**target_tokens).last_hidden_state
+ 
+         prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
+             seq_prot_attention_list, prot_seq_attention_list = self.repeated_module(peptide_sequence,
+                                                                                     protein_sequence)
+ 
+         prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
+ 
+         prot_enc = self.final_ffn(prot_enc)
+ 
+         prot_enc = self.output_projection_prot(prot_enc)
+ 
+         return prot_enc
+ 
+     def get_probs(self, x_t, target_sequence):
+         '''
+         Inputs:
+         - x_t: shape (bsz, seq_len)
+         - target_sequence: shape (1, tgt_len)
+         '''
+         target_sequence = target_sequence.repeat(x_t.shape[0], 1)
+         binder_attention_mask = torch.ones_like(x_t)
+         target_attention_mask = torch.ones_like(target_sequence)
+ 
+         # Mask out the BOS/EOS special tokens at both ends.
+         binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
+         target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0
+ 
+         binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)}
+         target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}
+ 
+         logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
+         # Force the special-token positions to near-zero probability (sigmoid(-100) ~ 0).
+         logits[:, 0] = logits[:, -1] = -100
+         probs = torch.sigmoid(logits)
+ 
+         return probs  # shape (bsz, tgt_len)
+ 
+     def motif_score(self, x_t, target_sequence, motifs):
+         probs = self.get_probs(x_t, target_sequence)
+         motif_probs = probs[:, motifs]
+         # Mean predicted binding probability over the target motif positions.
+         motif_score = motif_probs.sum(dim=-1) / len(motifs)
+         return motif_score
+ 
+     def non_motif_score(self, x_t, target_sequence, motifs):
+         probs = self.get_probs(x_t, target_sequence)
+         non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
+         mask = non_motif_probs >= 0.5
+         count = mask.sum(dim=-1)
+ 
+         # Mean probability over confidently predicted off-motif positions (0 if there are none).
+         non_motif_score = torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count,
+                                       torch.zeros_like(count, dtype=probs.dtype))
+ 
+         return non_motif_score
+ 
+     def scoring(self, x_t, target_sequence, motifs, penalty=False):
+         probs = self.get_probs(x_t, target_sequence)
+         motif_probs = probs[:, motifs]
+         motif_score = motif_probs.sum(dim=-1) / len(motifs)
+ 
+         if penalty:
+             non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
+             mask = non_motif_probs >= 0.5
+             count = mask.sum(dim=-1)
+             # Penalty term: fraction of off-motif positions predicted as binding.
+             non_motif_score = count / target_sequence.shape[1]
+             return motif_score, 1 - non_motif_score
+         else:
+             return motif_score
+ 
+ class MotifModel(nn.Module):
+     def __init__(self, bindevaluator, target_sequence, motifs, penalty=False):
+         super(MotifModel, self).__init__()
+         self.bindevaluator = bindevaluator
+         self.target_sequence = target_sequence
+         self.motifs = motifs
+         self.penalty = penalty
+ 
+     def forward(self, x):
+         return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
+ 
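A usage sketch of the wrapper above; the checkpoint path and the pre-tokenized tensors are assumptions:

    bindevaluator = load_bindevaluator('/path/to/bindevaluator.ckpt', device)  # defined later in this file
    motifs = parse_motifs('10-12, 15')
    scorer = MotifModel(bindevaluator, target_token_ids, motifs, penalty=False)  # target_token_ids: (1, tgt_len)
    scores = scorer(binder_token_ids)  # (bsz,) mean binding probability over the motif positions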
+ class UnpooledBindingPredictor(nn.Module):
+     def __init__(self,
+                  esm_model_name="facebook/esm2_t33_650M_UR50D",
+                  hidden_dim=512,
+                  kernel_sizes=[3, 5, 7],
+                  n_heads=8,
+                  n_layers=3,
+                  dropout=0.1,
+                  freeze_esm=True):
+         super().__init__()
+ 
+         # Define binding thresholds (pKd/pKi/pIC50 scale)
+         self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
+         self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM
+ 
+         # Load ESM model for computing embeddings on the fly
+         self.esm_model = AutoModel.from_pretrained(esm_model_name)
+         self.config = AutoConfig.from_pretrained(esm_model_name)
+ 
+         # Freeze ESM parameters if needed
+         if freeze_esm:
+             for param in self.esm_model.parameters():
+                 param.requires_grad = False
+ 
+         # Get ESM hidden size
+         esm_dim = self.config.hidden_size
+ 
+         # Output channels for CNN layers
+         output_channels_per_kernel = 64
+ 
+         # CNN layers for handling variable length sequences
+         self.protein_conv_layers = nn.ModuleList([
+             nn.Conv1d(
+                 in_channels=esm_dim,
+                 out_channels=output_channels_per_kernel,
+                 kernel_size=k,
+                 padding='same'
+             ) for k in kernel_sizes
+         ])
+ 
+         self.binder_conv_layers = nn.ModuleList([
+             nn.Conv1d(
+                 in_channels=esm_dim,
+                 out_channels=output_channels_per_kernel,
+                 kernel_size=k,
+                 padding='same'
+             ) for k in kernel_sizes
+         ])
+ 
+         # Total features after convolution and max+avg pooling
+         total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2
+ 
+         # Project to same dimension after CNN processing
+         self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
+         self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)
+ 
+         self.protein_norm = nn.LayerNorm(hidden_dim)
+         self.binder_norm = nn.LayerNorm(hidden_dim)
+ 
+         # Cross attention blocks with layer norm
+         self.cross_attention_layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                 'norm1': nn.LayerNorm(hidden_dim),
+                 'ffn': nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim * 4),
+                     nn.ReLU(),
+                     nn.Dropout(dropout),
+                     nn.Linear(hidden_dim * 4, hidden_dim)
+                 ),
+                 'norm2': nn.LayerNorm(hidden_dim)
+             }) for _ in range(n_layers)
+         ])
+ 
+         # Prediction heads
+         self.shared_head = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+ 
+         # Regression head
+         self.regression_head = nn.Linear(hidden_dim, 1)
+ 
+         # Classification head (3 classes: tight, medium, weak binding)
+         self.classification_head = nn.Linear(hidden_dim, 3)
+ 
+     def get_binding_class(self, affinity):
+         """Convert affinity values to class indices
+         0: tight binding (>= 7.5)
+         1: medium binding (6.0-7.5)
+         2: weak binding (< 6.0)
+         """
+         if isinstance(affinity, torch.Tensor):
+             tight_mask = affinity >= self.tight_threshold
+             weak_mask = affinity < self.weak_threshold
+             medium_mask = ~(tight_mask | weak_mask)
+ 
+             classes = torch.zeros_like(affinity, dtype=torch.long)
+             classes[medium_mask] = 1
+             classes[weak_mask] = 2
+             return classes
+         else:
+             if affinity >= self.tight_threshold:
+                 return 0  # tight binding
+             elif affinity < self.weak_threshold:
+                 return 2  # weak binding
+             else:
+                 return 1  # medium binding
+ 
+     def compute_embeddings(self, input_ids, attention_mask=None):
+         """Compute ESM embeddings on the fly"""
+         esm_outputs = self.esm_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             return_dict=True
+         )
+ 
+         # Return the unpooled last hidden states (batch_size x seq_length x hidden_size)
+         return esm_outputs.last_hidden_state
+ 
+     def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
+         """Process a sequence through CNN layers and pooling"""
+         # Transpose for CNN: [batch_size, hidden_size, seq_length]
+         x = unpooled_emb.transpose(1, 2)
+ 
+         # Apply CNN layers and collect outputs
+         conv_outputs = []
+         for conv in conv_layers:
+             conv_out = F.relu(conv(x))
+             conv_outputs.append(conv_out)
+ 
+         # Concatenate along channel dimension
+         conv_output = torch.cat(conv_outputs, dim=1)
+ 
+         # Global pooling (both max and average); if an attention mask is
+         # provided, use it to exclude padding from both pools.
+         if attention_mask is not None:
+             # Expand the mask to match conv_output channels (1 = valid, 0 = padding)
+             expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)
+ 
+             # Apply mask (set padding to -inf for max pooling)
+             masked_output = conv_output.clone()
+             masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))
+ 
+             # Max pooling along sequence dimension
+             max_pooled = torch.max(masked_output, dim=2)[0]
+ 
+             # Average pooling (sum divided by number of valid positions)
+             sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
+             valid_positions = torch.sum(expanded_mask, dim=2)
+             valid_positions = torch.clamp(valid_positions, min=1.0)  # avoid division by zero
+             avg_pooled = sum_pooled / valid_positions
+         else:
+             # If no mask, use standard pooling
+             max_pooled = torch.max(conv_output, dim=2)[0]
+             avg_pooled = torch.mean(conv_output, dim=2)
+ 
+         # Concatenate the pooled features
+         pooled = torch.cat([max_pooled, avg_pooled], dim=1)
+ 
+         return pooled
+ 
+     def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
+         # Compute embeddings on the fly using the ESM model
+         protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
+         binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)
+ 
+         # Process protein and binder sequences through CNN layers
+         protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
+         binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)
+ 
+         # Project to same dimension
+         protein = self.protein_norm(self.protein_projection(protein_features))
+         binder = self.binder_norm(self.binder_projection(binder_features))
+ 
+         # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim]
+         protein = protein.unsqueeze(0)
+         binder = binder.unsqueeze(0)
+ 
+         # Cross attention layers
+         for layer in self.cross_attention_layers:
+             # Protein attending to binder
+             attended_protein = layer['attention'](
+                 protein, binder, binder
+             )[0]
+             protein = layer['norm1'](protein + attended_protein)
+             protein = layer['norm2'](protein + layer['ffn'](protein))
+ 
+             # Binder attending to protein
+             attended_binder = layer['attention'](
+                 binder, protein, protein
+             )[0]
+             binder = layer['norm1'](binder + attended_binder)
+             binder = layer['norm2'](binder + layer['ffn'](binder))
+ 
+         # Remove sequence dimension
+         protein_pool = protein.squeeze(0)
+         binder_pool = binder.squeeze(0)
+ 
+         # Concatenate both representations
+         combined = torch.cat([protein_pool, binder_pool], dim=-1)
+ 
+         # Shared features
+         shared_features = self.shared_head(combined)
+ 
+         regression_output = self.regression_head(shared_features)
+         # classification_logits = self.classification_head(shared_features)
+ 
+         # return regression_output, classification_logits
+         return regression_output
+ 
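A worked example of the thresholds above (constructing the module downloads the frozen ESM weights):

    predictor = UnpooledBindingPredictor()
    affinities = torch.tensor([8.2, 6.7, 5.1])   # pKd-scale values
    predictor.get_binding_class(affinities)      # tensor([0, 1, 2]): tight, medium, weak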
+ class ImprovedBindingPredictor(nn.Module):
+     def __init__(self,
+                  esm_dim=1280,
+                  smiles_dim=1280,
+                  hidden_dim=512,
+                  n_heads=8,
+                  n_layers=5,
+                  dropout=0.1):
+         super().__init__()
+ 
+         # Define binding thresholds (pKd/pKi/pIC50 scale)
+         self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
+         self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM
+ 
+         # Project both inputs to the same dimension ("smiles" here refers to the binder embedding)
+         self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
+         self.protein_projection = nn.Linear(esm_dim, hidden_dim)
+         self.protein_norm = nn.LayerNorm(hidden_dim)
+         self.smiles_norm = nn.LayerNorm(hidden_dim)
+ 
+         # Cross attention blocks with layer norm
+         self.cross_attention_layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                 'norm1': nn.LayerNorm(hidden_dim),
+                 'ffn': nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim * 4),
+                     nn.ReLU(),
+                     nn.Dropout(dropout),
+                     nn.Linear(hidden_dim * 4, hidden_dim)
+                 ),
+                 'norm2': nn.LayerNorm(hidden_dim)
+             }) for _ in range(n_layers)
+         ])
+ 
+         # Prediction heads
+         self.shared_head = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+ 
+         # Regression head
+         self.regression_head = nn.Linear(hidden_dim, 1)
+ 
+         # Classification head (3 classes: tight, medium, weak binding)
+         self.classification_head = nn.Linear(hidden_dim, 3)
+ 
+     def get_binding_class(self, affinity):
+         """Convert affinity values to class indices
+         0: tight binding (>= 7.5)
+         1: medium binding (6.0-7.5)
+         2: weak binding (< 6.0)
+         """
+         if isinstance(affinity, torch.Tensor):
+             tight_mask = affinity >= self.tight_threshold
+             weak_mask = affinity < self.weak_threshold
+             medium_mask = ~(tight_mask | weak_mask)
+ 
+             classes = torch.zeros_like(affinity, dtype=torch.long)
+             classes[medium_mask] = 1
+             classes[weak_mask] = 2
+             return classes
+         else:
+             if affinity >= self.tight_threshold:
+                 return 0  # tight binding
+             elif affinity < self.weak_threshold:
+                 return 2  # weak binding
+             else:
+                 return 1  # medium binding
+ 
+     def forward(self, protein_emb, binder_emb):
+         # protein_emb / binder_emb: unpooled embeddings of shape (B, L, dim)
+         protein = self.protein_norm(self.protein_projection(protein_emb))
+         smiles = self.smiles_norm(self.smiles_projection(binder_emb))
+ 
+         # nn.MultiheadAttention expects (L, B, hidden_dim)
+         protein = protein.transpose(0, 1)
+         smiles = smiles.transpose(0, 1)
+ 
+         # Cross attention layers
+         for layer in self.cross_attention_layers:
+             # Protein attending to binder
+             attended_protein = layer['attention'](
+                 protein, smiles, smiles
+             )[0]
+             protein = layer['norm1'](protein + attended_protein)
+             protein = layer['norm2'](protein + layer['ffn'](protein))
+ 
+             # Binder attending to protein
+             attended_smiles = layer['attention'](
+                 smiles, protein, protein
+             )[0]
+             smiles = layer['norm1'](smiles + attended_smiles)
+             smiles = layer['norm2'](smiles + layer['ffn'](smiles))
+ 
+         # Sequence-level representations via mean pooling over positions
+         protein_pool = torch.mean(protein, dim=0)
+         smiles_pool = torch.mean(smiles, dim=0)
+ 
+         # Concatenate both representations
+         combined = torch.cat([protein_pool, smiles_pool], dim=-1)
+ 
+         # Shared features
+         shared_features = self.shared_head(combined)
+ 
+         regression_output = self.regression_head(shared_features)
+ 
+         return regression_output
+ 
+ class PooledAffinityModel(nn.Module):
+     def __init__(self, affinity_predictor, target_sequence):
+         super(PooledAffinityModel, self).__init__()
+         self.affinity_predictor = affinity_predictor
+         self.target_sequence = target_sequence
+         self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device)
+         for param in self.esm_model.parameters():
+             param.requires_grad = False
+ 
+     def compute_embeddings(self, input_ids, attention_mask=None):
+         """Compute ESM embeddings on the fly"""
+         esm_outputs = self.esm_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             return_dict=True
+         )
+ 
+         # Return the unpooled last hidden states (batch_size x seq_length x hidden_size)
+         return esm_outputs.last_hidden_state
+ 
+     def forward(self, x):
+         target_sequence = self.target_sequence.repeat(x.shape[0], 1)
+ 
+         protein_emb = self.compute_embeddings(input_ids=target_sequence)
+         binder_emb = self.compute_embeddings(input_ids=x)
+         return self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb).squeeze(-1)
+ 
+ class AffinityModel(nn.Module):
+     def __init__(self, affinity_predictor, target_sequence):
+         super(AffinityModel, self).__init__()
+         self.affinity_predictor = affinity_predictor
+         self.target_sequence = target_sequence
+ 
+     def forward(self, x):
+         target_sequence = self.target_sequence.repeat(x.shape[0], 1)
+         affinity = self.affinity_predictor(protein_input_ids=target_sequence, binder_input_ids=x).squeeze(-1)
+         # Scale the pKd-style prediction (roughly 0-10) into [0, 1] for guidance.
+         return affinity / 10
+ 
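How the pooled variant is typically wired together; the checkpoint path and token tensors are assumptions:

    predictor = load_pooled_affinity_predictor('/path/to/affinity.pt', device)  # defined later in this file
    model = PooledAffinityModel(predictor, target_token_ids)  # target_token_ids: (1, tgt_len) on `device`
    affinities = model(binder_token_ids)                      # (bsz,) predicted pKd-scale affinities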
+ class HemolysisModel:
+     def __init__(self, device):
+         self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_hemolysis.json')
+ 
+         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
+         self.model.eval()
+ 
+         self.device = device
+ 
+     def generate_embeddings(self, sequences):
+         """Generate mean-pooled ESM embeddings for tokenized protein sequences"""
+         with torch.no_grad():
+             embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
+             embeddings = embeddings.cpu().numpy()
+ 
+         return embeddings
+ 
+     def get_scores(self, input_seqs):
+         scores = np.ones(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+ 
+         if len(features) == 0:
+             return torch.from_numpy(scores).to(self.device)
+ 
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+ 
+         features = xgb.DMatrix(features)
+ 
+         probs = self.predictor.predict(features)
+         # Return the probability of being non-hemolytic, i.e. 1 - P(hemolytic).
+         return torch.from_numpy(scores - probs).to(self.device)
+ 
+     def __call__(self, input_seqs):
+         # input_seqs: tensor of ESM token ids, shape (bsz, seq_len)
+         scores = self.get_scores(input_seqs)
+         return scores
+ 
+ class NonfoulingModel:
+     def __init__(self, device):
+         # change model path
+         self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_nonfouling.json')
+ 
+         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
+         self.model.eval()
+ 
+         self.device = device
+ 
+     def generate_embeddings(self, sequences):
+         """Generate mean-pooled ESM embeddings for tokenized protein sequences"""
+         with torch.no_grad():
+             embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
+             embeddings = embeddings.cpu().numpy()
+ 
+         return embeddings
+ 
+     def get_scores(self, input_seqs):
+         scores = np.zeros(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+ 
+         if len(features) == 0:
+             return torch.from_numpy(scores).to(self.device)
+ 
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+ 
+         features = xgb.DMatrix(features)
+ 
+         scores = self.predictor.predict(features)
+         return torch.from_numpy(scores).to(self.device)
+ 
+     def __call__(self, input_seqs):
+         scores = self.get_scores(input_seqs)
+         return scores
+ 
+ class SolubilityModel:
+     def __init__(self, device):
+         # change model path
+         self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')
+ 
+         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
+         self.model.eval()
+ 
+         self.device = device
+ 
+     def generate_embeddings(self, sequences):
+         """Generate mean-pooled ESM embeddings for tokenized protein sequences"""
+         with torch.no_grad():
+             embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
+             embeddings = embeddings.cpu().numpy()
+ 
+         return embeddings
+ 
+     def get_scores(self, input_seqs):
+         scores = np.zeros(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+ 
+         if len(features) == 0:
+             return torch.from_numpy(scores).to(self.device)
+ 
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+ 
+         features = xgb.DMatrix(features)
+ 
+         scores = self.predictor.predict(features)
+         return torch.from_numpy(scores).to(self.device)
+ 
+     def __call__(self, input_seqs):
+         scores = self.get_scores(input_seqs)
+         return scores
+ 
+ class SolubilityModelNew:
+     def __init__(self, device):
+         # Token ids of hydrophobic residues under the ESM-2 vocabulary (A, V, L, I, M, F, W, P).
+         self.hydro_ids = torch.tensor([5, 7, 4, 12, 20, 18, 22, 14], device=device)
+         self.device = device
+ 
+     def get_scores(self, x):
+         # Fraction of hydrophobic tokens per sequence; a higher score means fewer hydrophobic residues.
+         mask = (x.unsqueeze(-1) == self.hydro_ids).any(dim=-1)
+         ratios = mask.float().mean(dim=1)
+         return 1 - ratios
+ 
+     def __call__(self, input_seqs):
+         scores = self.get_scores(input_seqs)
+         return scores
+ 
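A small worked example of the hydrophobicity heuristic in SolubilityModelNew:

    sol = SolubilityModelNew(device='cpu')
    x = torch.tensor([[4, 5, 8, 9]])   # tokens L, A, S, E: two hydrophobic residues out of four
    sol(x)                             # tensor([0.5000]) = 1 - hydrophobic fraction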
+ class PeptideCNN(nn.Module):
+     def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
+         super().__init__()
+         self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
+         self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
+         self.fc = nn.Linear(hidden_dims[1], output_dim)
+         self.dropout = nn.Dropout(dropout_rate)
+         self.predictor = nn.Linear(output_dim, 1)  # for regression/classification
+ 
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.esm_model.eval()
+ 
+     def forward(self, input_ids, attention_mask=None, return_features=False):
+         with torch.no_grad():
+             x = self.esm_model(input_ids, attention_mask).last_hidden_state
+         # x shape: (B, L, input_dim)
+         x = x.permute(0, 2, 1)  # reshape to (B, input_dim, L) for Conv1d
+         x = nn.functional.relu(self.conv1(x))
+         x = self.dropout(x)
+         # kernel_size=5 with padding=1 shortens the sequence by 2; the pooling below absorbs this.
+         x = nn.functional.relu(self.conv2(x))
+         x = self.dropout(x)
+         x = x.permute(0, 2, 1)  # reshape back to (B, L', hidden_dims[1])
+ 
+         # Global average pooling over the sequence dimension
+         x = x.mean(dim=1)  # shape: (B, hidden_dims[1])
+ 
+         features = self.fc(x)  # features shape: (B, output_dim)
+         if return_features:
+             return features
+         return self.predictor(features)  # output shape: (B, 1)
+ 
+ class HalfLifeModel:
+     def __init__(self, device):
+         input_dim = 1280
+         hidden_dims = [input_dim // 2, input_dim // 4]
+         output_dim = input_dim // 8
+         dropout_rate = 0.3
+         self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
+         self.model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
+         self.model.eval()
+ 
+     def __call__(self, x):
+         prediction = self.model(x, return_features=False)
+         # Clamp to [0, 2] and rescale so the half-life score lies in [0, 1].
+         halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
+         return halflife / 2
+ 
+ 
+ def load_bindevaluator(checkpoint_path, device):
+     bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128,
+                                                        n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
+     bindevaluator.eval()
+     for param in bindevaluator.parameters():
+         param.requires_grad = False
+ 
+     return bindevaluator
+ 
+ 
+ def load_solver(checkpoint_path, vocab_size, device):
+     # Training-time hyperparameters, retained for reference (unused at inference).
+     lr = 1e-4
+     epochs = 200
+     embed_dim = 512
+     hidden_dim = 256
+     epsilon = 1e-3
+     batch_size = 256
+     warmup_epochs = epochs // 10
+ 
+     probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
+     probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
+     probability_denoiser.eval()
+     for param in probability_denoiser.parameters():
+         param.requires_grad = False
+ 
+     # Instantiate a convex probability path for discrete flow matching.
+     scheduler = PolynomialConvexScheduler(n=2.0)
+     path = MixtureDiscreteProbPath(scheduler=scheduler)
+ 
+     class WrappedModel(ModelWrapper):
+         def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
+             return torch.softmax(self.model(x, t), dim=-1)
+ 
+     wrapped_probability_denoiser = WrappedModel(probability_denoiser)
+     solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
+ 
+     return solver
+ 
+ 
+ def load_pooled_affinity_predictor(checkpoint_path, device):
+     """Load trained model from checkpoint."""
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+ 
+     model = ImprovedBindingPredictor().to(device)
+ 
+     # Load the trained weights
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()  # set to evaluation mode
+ 
+     return model
+ 
+ def load_affinity_predictor(checkpoint_path, device):
+     """Load trained model from checkpoint."""
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+ 
+     model = UnpooledBindingPredictor(
+         esm_model_name="facebook/esm2_t33_650M_UR50D",
+         hidden_dim=384,
+         kernel_sizes=[3, 5, 7],
+         n_heads=8,
+         n_layers=4,
+         dropout=0.14561457009902096,
+         freeze_esm=True
+     ).to(device)
+ 
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+ 
+     return model
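A minimal end-to-end sampling sketch with the loaded solver; the checkpoint path, vocabulary size, and step count are assumptions:

    solver = load_solver('/path/to/denoiser.pt', vocab_size=33, device='cuda')
    x_init = torch.randint(0, 33, (16, 12), device='cuda')     # 16 random sequences of length 12
    samples = solver.sample(x_init=x_init, step_size=1 / 100)  # Euler integration from t=0 to t=1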
models/peptide_models.py ADDED
@@ -0,0 +1,359 @@
+ from torch import nn
+ import torch
+ import numpy as np
+ from transformers import AutoModel
+ import torch.nn.functional as F
+ import esm
+ import copy
+ import pdb
+ 
+ class GaussianFourierProjection(nn.Module):
+     """
+     Gaussian random features for encoding time steps.
+     """
+ 
+     def __init__(self, embed_dim, scale=30.):
+         super().__init__()
+         # Randomly sample weights during initialization. These weights are fixed
+         # during optimization and are not trainable.
+         self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
+ 
+     def forward(self, x):
+         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ 
+ class Dense(nn.Module):
+     """
+     A fully connected layer whose output is broadcast over the sequence
+     dimension as a feature map.
+     """
+ 
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.dense = nn.Linear(input_dim, output_dim)
+ 
+     def forward(self, x):
+         return self.dense(x)
+ 
+ class Swish(nn.Module):
+     # Swish activation, x * sigmoid(x) (equivalent to nn.SiLU).
+     def __init__(self):
+         super().__init__()
+ 
+     def forward(self, x):
+         return torch.sigmoid(x) * x
+ 
+ class CNNESMModel(nn.Module):
+     """A time-dependent score-based model: frozen ESM-2 features feed a stack of dilated 1D convolutions."""
+ 
+     def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
+         """
+         Args:
+             embed_dim (int): Dimensionality of the token and time embeddings.
+                 Must match the ESM-2 hidden size (1280 for esm2_t33_650M), since
+                 the frozen encoder replaces the learned token embedding.
+         """
+         super().__init__()
+         self.alphabet_size = alphabet_size
+ 
+         # self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)
+         self.esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.esm.eval()
+         for param in self.esm.parameters():
+             param.requires_grad = False
+ 
+         self.time_embed = nn.Sequential(
+             GaussianFourierProjection(embed_dim=embed_dim),
+             nn.Linear(embed_dim, embed_dim)
+         )
+ 
+         self.swish = Swish()
+ 
+         n = hidden_dim
+ 
+         self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)
+ 
+         self.blocks = nn.ModuleList([
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256)
+         ])
+ 
+         self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
+         self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])
+ 
+         self.final = nn.Sequential(
+             nn.Conv1d(n, n, kernel_size=1),
+             nn.GELU(),
+             nn.Conv1d(n, self.alphabet_size, kernel_size=1)
+         )
+ 
+     def forward(self, x, t):
+         """
+         Args:
+             x: Tensor of shape (B, L) containing ESM token indices.
+             t: Tensor of shape (B,) containing the time steps.
+         Returns:
+             out: Tensor of shape (B, L, alphabet_size) with output logits per position.
+         """
+         # x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)
+         with torch.no_grad():
+             x = self.esm(input_ids=x).last_hidden_state  # (B, L, embed_dim)
+         time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)
+ 
+         out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
+         out = self.swish(self.linear(out))  # (B, n, L)
+ 
+         # Process through convolutional blocks, adding time conditioning via dense layers.
+         for block, dense, norm in zip(self.blocks, self.denses, self.norms):
+             # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
+             h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
+             # Residual connection if shapes match.
+             if h.shape == out.shape:
+                 out = h + out
+             else:
+                 out = h
+ 
+         out = self.final(out)  # (B, alphabet_size, L)
+         out = out.permute(0, 2, 1)  # (B, L, alphabet_size)
+ 
+         # Zero-center the logits along the alphabet dimension.
+         out = out - out.mean(dim=-1, keepdim=True)
+         return out
+ 
+ 
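Because the frozen encoder supplies the features, `embed_dim` has to match the ESM hidden size; a shape sketch (the vocabulary size is an assumption):

    model = CNNESMModel(alphabet_size=33, embed_dim=1280, hidden_dim=256)
    tokens = torch.randint(4, 24, (2, 12))   # amino-acid token ids
    logits = model(tokens, torch.rand(2))    # (2, 12, 33)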
+ class MLPModel(nn.Module):
+     def __init__(
+             self, input_dim: int = 128, time_dim: int = 1, hidden_dim=128, length=500):
+         super().__init__()
+         self.input_dim = input_dim
+         self.time_dim = time_dim
+         self.hidden_dim = hidden_dim
+ 
+         self.time_embedding = nn.Linear(1, time_dim)
+         self.token_embedding = torch.nn.Embedding(self.input_dim, hidden_dim)
+ 
+         self.swish = Swish()
+ 
+         # The flattened input ties the network to sequences of exactly `length` tokens.
+         self.main = nn.Sequential(
+             self.swish,
+             nn.Linear(hidden_dim * length + time_dim, hidden_dim),
+             self.swish,
+             nn.Linear(hidden_dim, hidden_dim),
+             self.swish,
+             nn.Linear(hidden_dim, hidden_dim),
+             self.swish,
+             nn.Linear(hidden_dim, self.input_dim * length),
+         )
+ 
+     def forward(self, x, t):
+         '''
+         x: shape (B, L)
+         t: shape (B,)
+         '''
+         t = self.time_embedding(t.unsqueeze(-1))
+         x = self.token_embedding(x)
+ 
+         B, N, d = x.shape
+         x = x.reshape(B, N * d)
+ 
+         h = torch.cat([x, t], dim=1)
+         h = self.main(h)
+ 
+         h = h.reshape(B, N, self.input_dim)
+ 
+         return h
+ 
+ class CNNModel(nn.Module):
+     """A time-dependent score-based model built on a stack of dilated 1D convolutions."""
+ 
+     def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
+         """
+         Args:
+             embed_dim (int): Dimensionality of the token and time embeddings.
+         """
+         super().__init__()
+         self.alphabet_size = alphabet_size
+ 
+         self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)
+         # self.esm = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         # self.esm.eval()
+         # for param in self.esm.parameters():
+         #     param.requires_grad = False
+ 
+         self.time_embed = nn.Sequential(
+             GaussianFourierProjection(embed_dim=embed_dim),
+             nn.Linear(embed_dim, embed_dim)
+         )
+ 
+         self.swish = Swish()
+ 
+         n = hidden_dim
+ 
+         self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)
+ 
+         self.blocks = nn.ModuleList([
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, padding=4),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             # nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256)
+         ])
+ 
+         self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(5)])
+         self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(5)])
+ 
+         self.final = nn.Sequential(
+             nn.Conv1d(n, n, kernel_size=1),
+             nn.GELU(),
+             nn.Conv1d(n, self.alphabet_size, kernel_size=1)
+         )
+ 
+     def forward(self, x, t):
+         """
+         Args:
+             x: Tensor of shape (B, L) containing token indices.
+             t: Tensor of shape (B,) containing the time steps.
+         Returns:
+             out: Tensor of shape (B, L, alphabet_size) with output logits per position.
+         """
+         x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)
+         # with torch.no_grad():
+         #     x = self.esm(input_ids=x).last_hidden_state
+         time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)
+ 
+         out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
+         out = self.swish(self.linear(out))  # (B, n, L)
+ 
+         # Process through convolutional blocks, adding time conditioning via dense layers.
+         for block, dense, norm in zip(self.blocks, self.denses, self.norms):
+             # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
+             h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
+             # Residual connection if shapes match.
+             if h.shape == out.shape:
+                 out = h + out
+             else:
+                 out = h
+ 
+         out = self.final(out)  # (B, alphabet_size, L)
+         out = out.permute(0, 2, 1)  # (B, L, alphabet_size)
+ 
+         # Zero-center the logits along the alphabet dimension.
+         out = out - out.mean(dim=-1, keepdim=True)
+         return out
+ 
+ class CNNModel_Large(nn.Module):
+     """A deeper CNNModel variant with four stacked groups of dilated convolution blocks (20 layers)."""
+ 
+     def __init__(self, alphabet_size=4, embed_dim=256, hidden_dim=256):
+         """
+         Args:
+             embed_dim (int): Dimensionality of the token and time embeddings.
+         """
+         super().__init__()
+         self.alphabet_size = alphabet_size
+ 
+         self.token_embedding = nn.Embedding(self.alphabet_size, embed_dim)
+ 
+         self.time_embed = nn.Sequential(
+             GaussianFourierProjection(embed_dim=embed_dim),
+             nn.Linear(embed_dim, embed_dim)
+         )
+ 
+         self.swish = Swish()
+ 
+         n = hidden_dim
+ 
+         self.linear = nn.Conv1d(embed_dim, n, kernel_size=9, padding=4)
+ 
+         self.blocks = nn.ModuleList([
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, padding=4),
+             nn.Conv1d(n, n, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(n, n, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(n, n, kernel_size=9, dilation=64, padding=256)
+         ])
+ 
+         self.denses = nn.ModuleList([Dense(embed_dim, n) for _ in range(20)])
+         self.norms = nn.ModuleList([nn.GroupNorm(1, n) for _ in range(20)])
+ 
+         self.final = nn.Sequential(
+             nn.Conv1d(n, n, kernel_size=1),
+             nn.GELU(),
+             nn.Conv1d(n, self.alphabet_size, kernel_size=1)
+         )
+ 
+     def forward(self, x, t):
+         """
+         Args:
+             x: Tensor of shape (B, L) containing token indices.
+             t: Tensor of shape (B,) containing the time steps.
+         Returns:
+             out: Tensor of shape (B, L, alphabet_size) with output logits per position.
+         """
+         x = self.token_embedding(x)  # (B, L) -> (B, L, embed_dim)
+         time_embed = self.swish(self.time_embed(t))  # (B, embed_dim)
+ 
+         out = x.permute(0, 2, 1)  # (B, L, embed_dim) -> (B, embed_dim, L)
+         out = self.swish(self.linear(out))  # (B, n, L)
+ 
+         # Process through convolutional blocks, adding time conditioning via dense layers.
+         for block, dense, norm in zip(self.blocks, self.denses, self.norms):
+             # dense(time_embed) gives (B, n); unsqueeze to (B, n, 1) for broadcasting.
+             h = self.swish(block(norm(out + dense(time_embed)[:, :, None])))
+             # Residual connection if shapes match.
+             if h.shape == out.shape:
+                 out = h + out
+             else:
+                 out = h
+ 
+         out = self.final(out)  # (B, alphabet_size, L)
+         out = out.permute(0, 2, 1)  # (B, L, alphabet_size)
+ 
+         # Zero-center the logits along the alphabet dimension.
+         out = out - out.mean(dim=-1, keepdim=True)
+         return out
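A quick way to compare the capacity of the two variants (a rough sketch):

    small = CNNModel(alphabet_size=33, embed_dim=512, hidden_dim=256)
    large = CNNModel_Large(alphabet_size=33, embed_dim=512, hidden_dim=256)
    n_small = sum(p.numel() for p in small.parameters())
    n_large = sum(p.numel() for p in large.parameters())   # roughly 4x the conv-block parameters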
modules/bindevaluator_modules/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .models import *
+ from .score_domain import *
+ from .dataloaders import *
modules/bindevaluator_modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (234 Bytes).
modules/bindevaluator_modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (228 Bytes).
modules/bindevaluator_modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (234 Bytes).
modules/bindevaluator_modules/__pycache__/dataloaders.cpython-310.pyc ADDED
Binary file (7.93 kB).
modules/bindevaluator_modules/__pycache__/dataloaders.cpython-38.pyc ADDED
Binary file (8.44 kB).
modules/bindevaluator_modules/__pycache__/dataloaders.cpython-39.pyc ADDED
Binary file (8.59 kB).
modules/bindevaluator_modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (3.58 kB).
modules/bindevaluator_modules/__pycache__/layers.cpython-38.pyc ADDED
Binary file (3.68 kB).