Commit 666e0ff · Niksa Praljak
1 parent: 66d2e5f
Update facilitator section README.md
README.md CHANGED
@@ -128,13 +128,99 @@ tensor([[1.0000, 0.1840],
 ## Stage 2: Facilitator Sampling

-

-This stage will contain scripts and models for the Facilitator Sampling process. Check back for:
-- Configuration files
-- Model weights
-- Running instructions
-- Output examples
+### Overview
+
+In this stage, the **Facilitator model** takes the text embeddings (z_t) computed in Stage 1 and generates **facilitated embeddings (z_c)**. The facilitated embeddings align more closely with the protein embeddings (z_p) than the raw text embeddings do, as measured by **Mean Squared Error (MSE)** and **Maximum Mean Discrepancy (MMD)**.
+
+### Model Weights
+
+Before running the model, ensure you have:
+- Configuration file: `stage2_facilitator_config.json`
+- Pre-trained weights: `BioM3_Facilitator_epoch20.bin`
+
+### Running the Facilitator Model
+
+1. Clone the repository:
+```bash
+git clone https://huggingface.co/your_username/BioM3_Facilitator
+cd BioM3_Facilitator
+```
+
+2. Run inference:
+```bash
+python run_facilitator_inference.py \
+    --json_path "stage2_facilitator_config.json" \
+    --model_path "./weights/Facilitator/BioM3_Facilitator_epoch20.bin" \
+    --input_data_path "outputs/Stage1_test_prompts_PDZ.pt" \
+    --output_data_path "outputs/Stage2_test_prompts_PDZ.pt"
+```
+
+Arguments:
+- **json_path**: Path to the JSON configuration file
+- **model_path**: Path to the pre-trained facilitator weights
+- **input_data_path**: Path to the input embeddings (z_t and z_p) generated in Stage 1
+- **output_data_path**: Path to save the facilitated embeddings (z_c)
+
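For reference, the four flags listed above could be parsed with Python's standard `argparse` module roughly as follows. This is a hypothetical sketch: the contents of `run_facilitator_inference.py` are not shown in this diff, and both the `build_parser` helper and the choice to mark every flag as required are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the four documented flags."""
    parser = argparse.ArgumentParser(description="Facilitator inference (sketch)")
    parser.add_argument("--json_path", required=True,
                        help="Path to the JSON configuration file")
    parser.add_argument("--model_path", required=True,
                        help="Path to the pre-trained facilitator weights")
    parser.add_argument("--input_data_path", required=True,
                        help="Path to the Stage 1 embeddings (z_t and z_p)")
    parser.add_argument("--output_data_path", required=True,
                        help="Where to save the facilitated embeddings (z_c)")
    return parser


# Parse the same values used in the command above, passed as an explicit list
# so the sketch does not depend on the real command line.
args = build_parser().parse_args([
    "--json_path", "stage2_facilitator_config.json",
    "--model_path", "./weights/Facilitator/BioM3_Facilitator_epoch20.bin",
    "--input_data_path", "outputs/Stage1_test_prompts_PDZ.pt",
    "--output_data_path", "outputs/Stage2_test_prompts_PDZ.pt",
])
print(args.input_data_path, "->", args.output_data_path)
```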
+### Expected Output
+
+The script provides the following outputs:
+
+1. **Latent Embedding Shapes**
+   - z_t: Text embeddings
+   - z_p: Protein embeddings
+   - z_c: Facilitated embeddings
+
+2. **Vector Magnitudes**
+   - L2 norms of z_t, z_p, and z_c for a given batch
+
+3. **Mean Squared Error (MSE)**
+   - MSE between facilitated embeddings (z_c) and protein embeddings (z_p)
+   - MSE between text embeddings (z_t) and protein embeddings (z_p)
+
+4. **Maximum Mean Discrepancy (MMD)**
+   - MMD between facilitated embeddings (z_c) and protein embeddings (z_p)
+   - MMD between text embeddings (z_t) and protein embeddings (z_p)
+
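The metrics listed above can be illustrated with a minimal, dependency-free sketch. It is not the script's implementation: the real code presumably operates on PyTorch tensors, and the Gaussian (RBF) kernel, the bandwidth `sigma`, the biased squared-MMD estimator, and the toy values are all assumptions made here for illustration.

```python
import math


def l2_norm(v):
    """Euclidean (L2) norm of a single embedding vector."""
    return math.sqrt(sum(x * x for x in v))


def mse(a, b):
    """Mean squared error over all elements of two equally shaped batches."""
    total = sum((x - y) ** 2
                for row_a, row_b in zip(a, b)
                for x, y in zip(row_a, row_b))
    count = sum(len(row) for row in a)
    return total / count


def rbf_kernel(u, v, sigma=1.0):
    """Gaussian kernel between two vectors (kernel choice is an assumption)."""
    sq_dist = sum((x - y) ** 2 for x, y in zip(u, v))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd_squared(a, b, sigma=1.0):
    """Biased estimate of the squared MMD between sample sets a and b."""
    def mean_kernel(xs, ys):
        return sum(rbf_kernel(x, y, sigma) for x in xs for y in ys) / (len(xs) * len(ys))
    return mean_kernel(a, a) + mean_kernel(b, b) - 2.0 * mean_kernel(a, b)


# Toy 2x3 batches standing in for z_p, z_c and z_t (made-up values).
z_p = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
z_c = [[0.9, 0.1, 0.0], [0.1, 0.9, 0.0]]
z_t = [[3.0, 2.0, 1.0], [2.0, 3.0, 1.0]]

print("norms (z_t, z_p, z_c):", l2_norm(z_t[0]), l2_norm(z_p[0]), l2_norm(z_c[0]))
print("MSE(z_c, z_p):", mse(z_c, z_p))
print("MSE(z_t, z_p):", mse(z_t, z_p))
print("MMD^2(z_c, z_p):", mmd_squared(z_c, z_p))
print("MMD^2(z_t, z_p):", mmd_squared(z_t, z_p))
```

MSE compares embeddings pairwise, element by element, while MMD compares the two sets of embeddings as distributions, which is why the script reports both.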
+### Sample Output
+
+```plaintext
+=== Facilitator Model Output ===
+Shape of z_t (Text Embeddings): torch.Size([2, 512])
+Shape of z_p (Protein Embeddings): torch.Size([2, 512])
+Shape of z_c (Facilitated Embeddings): torch.Size([2, 512])
+
+=== Norm (L2 Magnitude) Results for Batch Index 0 ===
+Norm of z_t (Text Embedding): 29.697054
+Norm of z_p (Protein Embedding): 5.337610
+Norm of z_c (Facilitated Embedding): 3.244318
+
+=== Mean Squared Error (MSE) Results ===
+MSE between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): 0.069909
+MSE between Text Embeddings (z_t) and Protein Embeddings (z_p): 1.612812
+
+=== Max Mean Discrepancy (MMD) Results ===
+MMD between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): 0.000171
+MMD between Text Embeddings (z_t) and Protein Embeddings (z_p): 0.005172
+```
+
+### What the Output Means
+
+1. **Latent Shapes**:
+   - Confirms that z_c has the same shape as z_p and z_t
+
+2. **Norms**:
+   - The norm of z_c is much closer to the norm of z_p than the norm of z_t is, which suggests that the facilitator model brings the embeddings onto a similar scale
+
+3. **MSE**:
+   - The MSE between z_c and z_p is lower than the MSE between z_t and z_p, so z_c approximates z_p better than z_t does
+
+4. **MMD**:
+   - The MMD between z_c and z_p is lower than the MMD between z_t and z_p, so the **distribution** of z_c is closer to that of z_p than the distribution of the original z_t is
+
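As a quick sanity check on these interpretations, the figures from the sample output can be compared directly. The snippet below only reuses the numbers printed in the sample output; the variable names are illustrative.

```python
# Values copied from the sample output.
norm_t, norm_p, norm_c = 29.697054, 5.337610, 3.244318
mse_c_p, mse_t_p = 0.069909, 1.612812
mmd_c_p, mmd_t_p = 0.000171, 0.005172

# Absolute gap between each norm and the protein-embedding norm.
gap_t = abs(norm_t - norm_p)
gap_c = abs(norm_c - norm_p)

# How many times smaller each metric is for z_c than for z_t.
mse_ratio = mse_t_p / mse_c_p
mmd_ratio = mmd_t_p / mmd_c_p

print(f"norm gap, z_t vs z_p: {gap_t:.3f}; z_c vs z_p: {gap_c:.3f}")
print(f"MSE is {mse_ratio:.1f}x lower for z_c; MMD is {mmd_ratio:.1f}x lower")
```

For this batch, the norm gap shrinks from roughly 24.4 to 2.1, and both MSE and MMD drop by more than an order of magnitude.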
+### Saving the Output
+
+The facilitated embeddings are saved to the path given by `output_data_path` for use in later stages.

 ## Stage 3: ProteoScribe