wi-lab committed
Commit 197379c · verified · 1 Parent(s): e62ae84

Update README.md

Files changed (1)
  1. README.md +484 -23
README.md CHANGED
@@ -62,75 +62,53 @@ To handle variable-sized inputs efficiently, we implemented **bucket-based batch
62
 
63
  This approach eliminates the rigidity of fixed channel sizes and positions LWM-v1.1 as a versatile model capable of adapting to real-world wireless systems with varying configurations.
64
 
65
- ---
66
-
67
  ### **Larger and More Diverse Pretraining Dataset**
68
  Generalization is a critical aspect of any foundation model. In **LWM-v1.1**, we significantly expanded the training dataset to cover more diverse scenarios and environments. We added **seven new city scenarios**—Charlotte, Denver, Oklahoma, Indianapolis, Fort Worth, Santa Clara, and San Diego—to enrich the model’s exposure to a variety of urban layouts. To enhance the spatial resolution of the training data, we reduced the grid spacing between user locations in the DeepMIMO city scenarios from **2.5m to 1m**, resulting in a higher density of user positions. This adjustment required re-performing ray tracing for all scenarios to generate high-resolution wireless channel data.
69
 
70
  Additionally, we introduced **channels from multiple base stations** in each scenario, with distinct (N, SC) pairs to ensure the model encounters a broad range of channel characteristics. This diversity mirrors the variability found in real-world deployments, such as urban, suburban, and rural environments. By exposing LWM-v1.1 to this diversity, the model gains the ability to generalize across environments with distinct propagation characteristics, making it more reliable and versatile.
71
 
72
- ---
73
-
74
  ### **Fine-Tuning for Task-Specific Embedding Generation**
75
  While pretraining provides a robust feature extractor, downstream tasks often require tailored embeddings. In **LWM-v1.1**, we introduced **fine-tuning options** that give users the flexibility to customize the model for specific tasks. Users can now **freeze specific layers** of the model, allowing the remaining layers to adapt to task-specific requirements. This feature is particularly valuable for tasks prone to overfitting, such as **LoS/NLoS classification**, where excessive training on all layers can lead to suboptimal generalization.
76
 
77
  To further streamline task-specific adaptation, we provided **default classification and regression heads** for downstream tasks. Users can also define their own custom heads to suit unique requirements, ensuring maximum flexibility and adaptability.
78
 
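For illustration, here is a minimal sketch of what layer freezing and a custom head might look like in PyTorch (this is not the repository's exact API; the `layers.0`–`layers.11` naming follows the convention used later in this README, and the head dimensions are placeholder assumptions):

```python
import torch.nn as nn

# Sketch: freeze everything except the last four encoder layers of a loaded LWM model,
# then attach a small task-specific head. `model` stands for a pretrained LWM instance.
for name, param in model.named_parameters():
    param.requires_grad = any(name.startswith(f"layers.{i}.") for i in (8, 9, 10, 11))

custom_head = nn.Sequential(
    nn.Linear(128, 64),  # LWM-v1.1 embeddings are 128-dimensional
    nn.ReLU(),
    nn.Linear(64, 2)     # e.g., two classes for LoS/NLoS classification
)
```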
79
- ---
80
-
81
  ### **Increased Model Capacity**
82
  LWM-v1.1 significantly enhances the model's ability to extract complex features by increasing the **embedding size from 64 to 128**. This increase more than quadruples the model's parameter count, raising it from **600K to 2.5M**. The larger embedding size allows the model to represent more intricate relationships within channel data, improving its performance on challenging tasks such as **beam prediction** and **channel estimation**.
83
 
84
  This change directly impacts the quality of the embeddings, making them more expressive and robust across a variety of downstream tasks, even in scenarios with limited labeled data.
85
 
86
- ---
87
-
88
  ### **Challenging MCM Task with Higher Masking Ratio**
89
  The **Masked Channel Modeling (MCM)** task lies at the core of LWM’s pretraining methodology. In **LWM-v1.1**, we made the task more challenging by increasing the **masking ratio from 15% to 40%**. This means that a larger portion of the channel data is masked during training, requiring the model to infer the missing information from contextual dependencies.
90
 
91
  This enhancement forces the model to rely on deeper spatial relationships between antennas and subcarriers, rather than learning superficial patterns. As a result, LWM-v1.1 produces embeddings that are more robust and better equipped to handle real-world scenarios with incomplete or noisy data.
92
 
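For intuition, the masking step can be pictured as follows (an illustrative sketch only, not the actual pretraining code; the patch count is an arbitrary example):

```python
import numpy as np

# With a 40% masking ratio, 40% of the channel patches in each sample are hidden
# and must be reconstructed from the remaining, unmasked context.
n_patches, masking_ratio = 32, 0.40
n_masked = int(round(masking_ratio * n_patches))
masked_idx = np.random.choice(n_patches, size=n_masked, replace=False)
print(f"masking {n_masked}/{n_patches} patches:", np.sort(masked_idx))
```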
93
- ---
94
-
95
  ### **Support for Larger Input Sizes**
96
  Wireless communication systems are increasingly handling larger channels with higher dimensions. To accommodate these demands, we increased the **maximum sequence length** from **128 to 512** in **LWM-v1.1**. This change enables the model to process larger and more detailed channel data without modification, broadening its applicability to high-dimensional wireless tasks. This ensures that LWM-v1.1 remains relevant as the scale and complexity of wireless systems continue to grow.
97
 
98
- ---
99
-
100
  ### **2D Patch Segmentation for Realistic Learning**
101
  In **LWM-v1.0**, patches were segmented based on a single dimension, typically grouping elements from different subcarriers within the same antenna. In **LWM-v1.1**, we introduced **2D patch segmentation**, where patches now combine elements from both antennas and subcarriers. This reflects real-world wireless channel dependencies more accurately, as the relationship between antennas and subcarriers is critical in practical deployments.
102
 
103
  This multidimensional segmentation increases the complexity of the MCM task, requiring the model to learn deeper and more meaningful dependencies within the data. By better aligning the training methodology with real-world conditions, LWM-v1.1 further enhances its ability to generalize and perform in practical scenarios.
104
 
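Conceptually, 2D segmentation carves the antenna-subcarrier grid into blocks rather than rows. A minimal NumPy sketch (the 4×4 patch size is an assumption for illustration, not the model's actual patch dimensions):

```python
import numpy as np

N, SC = 32, 32                                   # antennas x subcarriers
channel = np.random.randn(N, SC) + 1j * np.random.randn(N, SC)
pa, ps = 4, 4                                    # assumed patch size per dimension
patches = channel.reshape(N // pa, pa, SC // ps, ps).transpose(0, 2, 1, 3)
patches = patches.reshape(-1, pa * ps)           # (n_patches, elements per patch)
print(patches.shape)                             # (64, 16)
```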
105
- ---
106
-
107
  ### **Optimized Training Strategy**
108
  Training large models requires carefully designed optimization techniques to ensure smooth convergence and generalization. In **LWM-v1.1**, we adopted the **AdamW optimizer**, which improves weight regularization and prevents overfitting compared to traditional Adam. The learning rate schedule was also refined, incorporating an **85-step warmup phase** followed by **cosine decay**. This strategy ensures that the model transitions smoothly from the initial training phase to convergence, maintaining stability and improving overall performance.
109
 
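A rough sketch of such a schedule in PyTorch (illustrative only; `model`, the peak learning rate, and `total_steps` are placeholders, not the actual pretraining configuration beyond the 85 warmup steps stated above):

```python
import math
import torch

warmup_steps, total_steps = 85, 10_000
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

def lr_lambda(step):
    if step < warmup_steps:
        return (step + 1) / warmup_steps                      # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))         # cosine decay

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Call scheduler.step() after each optimizer.step() to advance the schedule.
```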
110
- ---
111
-
112
  ### **Improved Computational Efficiency**
113
  To balance computational efficiency with performance, we reduced the number of **attention heads per layer from 12 to 8** in **LWM-v1.1**. This reduction decreases the computational load during both training and inference, making the model more efficient without significantly affecting its ability to extract meaningful features. The streamlined architecture ensures that LWM-v1.1 is not only powerful but also practical for deployment in resource-constrained environments.
114
 
115
- ---
116
-
117
  ### **Why These Changes Were Necessary**
118
  The updates in LWM-v1.1 were driven by real-world demands for greater flexibility, scalability, and performance in wireless communication tasks. Removing channel size limitations and diversifying the dataset address the variability inherent in wireless environments. Increasing model capacity and enhancing the MCM task improve the quality of embeddings, while optimized training strategies and computational efficiency make the model practical for a wide range of applications. These changes make LWM-v1.1 a significant step forward, ensuring its relevance and impact in advancing wireless communication research.
119
 
120
- ---
121
-
122
  ## **Conclusion**
123
  **LWM-v1.1** represents a major leap forward in wireless communication modeling, offering robust scalability, increased generalization, and adaptability to a wide variety of tasks. From enriched training datasets and challenging pretraining objectives to enhanced model capacity and efficient input handling, LWM-v1.1 provides a powerful foundation for wireless communication research and applications.
124
 
125
- ---
126
-
127
  ### **Try It Now!**
128
  Explore **LWM-v1.1** on Hugging Face with preloaded datasets, fine-tuning options, and pretrained models to kickstart your projects.
129
  [👉 Access the model here!](https://huggingface.co/wi-lab/lwm-v1.1)
130
 
131
  ---
132
 
133
- Please cite the following paper if you use the LWM model or any modifiled parts:
134
  ```
135
  @misc{alikhani2024largewirelessmodellwm,
136
  title={Large Wireless Model (LWM): A Foundation Model for Wireless Channels},
@@ -142,3 +120,486 @@ Please cite the following paper if you use the LWM model or any modifiled parts:
142
  url={https://arxiv.org/abs/2411.08872},
143
  }
144
  ```
62
 
63
  This approach eliminates the rigidity of fixed channel sizes and positions LWM-v1.1 as a versatile model capable of adapting to real-world wireless systems with varying configurations.
64
 
 
 
65
  ### **Larger and More Diverse Pretraining Dataset**
66
  Generalization is a critical aspect of any foundation model. In **LWM-v1.1**, we significantly expanded the training dataset to cover more diverse scenarios and environments. We added **seven new city scenarios**—Charlotte, Denver, Oklahoma, Indianapolis, Fort Worth, Santa Clara, and San Diego—to enrich the model’s exposure to a variety of urban layouts. To enhance the spatial resolution of the training data, we reduced the grid spacing between user locations in the DeepMIMO city scenarios from **2.5m to 1m**, resulting in a higher density of user positions. This adjustment required re-performing ray tracing for all scenarios to generate high-resolution wireless channel data.
67
 
68
  Additionally, we introduced **channels from multiple base stations** in each scenario, with distinct (N, SC) pairs to ensure the model encounters a broad range of channel characteristics. This diversity mirrors the variability found in real-world deployments, such as urban, suburban, and rural environments. By exposing LWM-v1.1 to this diversity, the model gains the ability to generalize across environments with distinct propagation characteristics, making it more reliable and versatile.
69
 
 
 
70
  ### **Fine-Tuning for Task-Specific Embedding Generation**
71
  While pretraining provides a robust feature extractor, downstream tasks often require tailored embeddings. In **LWM-v1.1**, we introduced **fine-tuning options** that give users the flexibility to customize the model for specific tasks. Users can now **freeze specific layers** of the model, allowing the remaining layers to adapt to task-specific requirements. This feature is particularly valuable for tasks prone to overfitting, such as **LoS/NLoS classification**, where excessive training on all layers can lead to suboptimal generalization.
72
 
73
  To further streamline task-specific adaptation, we provided **default classification and regression heads** for downstream tasks. Users can also define their own custom heads to suit unique requirements, ensuring maximum flexibility and adaptability.
74
 
 
 
75
  ### **Increased Model Capacity**
76
  LWM-v1.1 significantly enhances the model's ability to extract complex features by increasing the **embedding size from 64 to 128**. This increase more than quadruples the model's parameter count, raising it from **600K to 2.5M**. The larger embedding size allows the model to represent more intricate relationships within channel data, improving its performance on challenging tasks such as **beam prediction** and **channel estimation**.
77
 
78
  This change directly impacts the quality of the embeddings, making them more expressive and robust across a variety of downstream tasks, even in scenarios with limited labeled data.
79
 
 
 
80
  ### **Challenging MCM Task with Higher Masking Ratio**
81
  The **Masked Channel Modeling (MCM)** task lies at the core of LWM’s pretraining methodology. In **LWM-v1.1**, we made the task more challenging by increasing the **masking ratio from 15% to 40%**. This means that a larger portion of the channel data is masked during training, requiring the model to infer the missing information from contextual dependencies.
82
 
83
  This enhancement forces the model to rely on deeper spatial relationships between antennas and subcarriers, rather than learning superficial patterns. As a result, LWM-v1.1 produces embeddings that are more robust and better equipped to handle real-world scenarios with incomplete or noisy data.
84
 
 
 
85
  ### **Support for Larger Input Sizes**
86
  Wireless communication systems are increasingly handling larger channels with higher dimensions. To accommodate these demands, we increased the **maximum sequence length** from **128 to 512** in **LWM-v1.1**. This change enables the model to process larger and more detailed channel data without modification, broadening its applicability to high-dimensional wireless tasks. This ensures that LWM-v1.1 remains relevant as the scale and complexity of wireless systems continue to grow.
87
 
 
 
88
  ### **2D Patch Segmentation for Realistic Learning**
89
  In **LWM-v1.0**, patches were segmented based on a single dimension, typically grouping elements from different subcarriers within the same antenna. In **LWM-v1.1**, we introduced **2D patch segmentation**, where patches now combine elements from both antennas and subcarriers. This reflects real-world wireless channel dependencies more accurately, as the relationship between antennas and subcarriers is critical in practical deployments.
90
 
91
  This multidimensional segmentation increases the complexity of the MCM task, requiring the model to learn deeper and more meaningful dependencies within the data. By better aligning the training methodology with real-world conditions, LWM-v1.1 further enhances its ability to generalize and perform in practical scenarios.
92
 
 
 
93
  ### **Optimized Training Strategy**
94
  Training large models requires carefully designed optimization techniques to ensure smooth convergence and generalization. In **LWM-v1.1**, we adopted the **AdamW optimizer**, which improves weight regularization and prevents overfitting compared to traditional Adam. The learning rate schedule was also refined, incorporating an **85-step warmup phase** followed by **cosine decay**. This strategy ensures that the model transitions smoothly from the initial training phase to convergence, maintaining stability and improving overall performance.
95
 
 
 
96
  ### **Improved Computational Efficiency**
97
  To balance computational efficiency with performance, we reduced the number of **attention heads per layer from 12 to 8** in **LWM-v1.1**. This reduction decreases the computational load during both training and inference, making the model more efficient without significantly affecting its ability to extract meaningful features. The streamlined architecture ensures that LWM-v1.1 is not only powerful but also practical for deployment in resource-constrained environments.
98
 
 
 
99
  ### **Why These Changes Were Necessary**
100
  The updates in LWM-v1.1 were driven by real-world demands for greater flexibility, scalability, and performance in wireless communication tasks. Removing channel size limitations and diversifying the dataset address the variability inherent in wireless environments. Increasing model capacity and enhancing the MCM task improve the quality of embeddings, while optimized training strategies and computational efficiency make the model practical for a wide range of applications. These changes make LWM-v1.1 a significant step forward, ensuring its relevance and impact in advancing wireless communication research.
101
 
 
 
102
  ## **Conclusion**
103
  **LWM-v1.1** represents a major leap forward in wireless communication modeling, offering robust scalability, increased generalization, and adaptability to a wide variety of tasks. From enriched training datasets and challenging pretraining objectives to enhanced model capacity and efficient input handling, LWM-v1.1 provides a powerful foundation for wireless communication research and applications.
104
 
 
 
105
  ### **Try It Now!**
106
  Explore **LWM-v1.1** on Hugging Face with preloaded datasets, fine-tuning options, and pretrained models to kickstart your projects.
107
  [👉 Access the model here!](https://huggingface.co/wi-lab/lwm-v1.1)
108
 
109
  ---
110
 
111
+ Please cite the following paper if you use the LWM model or any modified parts:
112
  ```
113
  @misc{alikhani2024largewirelessmodellwm,
114
  title={Large Wireless Model (LWM): A Foundation Model for Wireless Channels},
 
120
  url={https://arxiv.org/abs/2411.08872},
121
  }
122
  ```
123
+
124
+ ---
125
+
126
+ ## 🛠 **How to Use**
127
+
128
+ ### 1. **Install Conda**
129
+
130
+ First, ensure that you have a package manager like **Conda** installed to manage your Python environments and packages. You can install **Conda** via **Anaconda** or **Miniconda**.
131
+
132
+ - **Anaconda** includes a comprehensive scientific package suite. Download it [here](https://www.anaconda.com/products/distribution).
133
+ - **Miniconda** is a lightweight version that includes only Conda and Python. Download it [here](https://docs.conda.io/en/latest/miniconda.html).
134
+
135
+ Once installed, you can use Conda to manage environments.
136
+
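You can confirm that Conda is available on your path with:

```bash
conda --version
```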
137
+ ---
138
+
139
+ ### 2. **Create a New Environment**
140
+
141
+ After installing Conda, follow these steps to create a new environment and install the required packages.
142
+
143
+ #### **Step 1: Create a new environment**
144
+
145
+ To begin, open the **Anaconda PowerShell Prompt** and create a new Conda environment named `lwm_env`:
146
+
147
+ ```bash
148
+ conda create -n lwm_env
149
+ ```
150
+
151
+ #### **Step 2: Activate the environment**
152
+
153
+ Activate the environment:
154
+
155
+ ```bash
156
+ conda activate lwm_env
157
+ ```
158
+
159
+ ---
160
+
161
+ ### 3. **Install Required Packages**
162
+
163
+ Once the environment is activated, install the necessary packages.
164
+
165
+ #### **Install CUDA-enabled PyTorch**
166
+
167
+ Although inference can run efficiently on a CPU, you may need a GPU for training more resource-intensive downstream tasks. Visit [this page](https://pytorch.org/get-started/locally/) and select the appropriate options based on your system's specifications. The website will generate a tailored installation command.
168
+
169
+ For instance, on an NVIDIA system, you can use a command like the following with the appropriate CUDA version for your system:
170
+
171
+ ```bash
172
+ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
173
+ ```
174
+
175
+ This command installs PyTorch with CUDA support for GPU-accelerated training. Ensure that the specified CUDA version is compatible with your system, adjusting it if necessary.
176
+
177
+ > **Note:** If you encounter issues installing CUDA-enabled PyTorch, verify your CUDA version compatibility. It might also be due to conflicting installation attempts—try a fresh environment.
178
+
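A quick, optional check that the installed build can see your GPU:

```python
import torch

print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```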
179
+ #### **Install Other Required Packages via Conda Forge**
180
+
181
+ ```bash
182
+ conda install python numpy pandas matplotlib tqdm -c conda-forge
183
+ ```
184
+
185
+ #### **Install DeepMIMOv3 with pip**
186
+
187
+ ```bash
188
+ pip install DeepMIMOv3
189
+ ```
190
+
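To confirm the installation, try importing the package (assuming it is importable under the same name as the pip package):

```python
import DeepMIMOv3  # should complete without errors if the installation succeeded
```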
191
+ ---
192
+
193
+ ### 4. **Clone the Dataset Scenarios**
194
+
195
+ The following functions will help you clone specific dataset scenarios from a repository:
196
+
197
+ ```python
198
+ import subprocess
199
+ import os
200
+
201
+ # Function to clone a specific dataset scenario folder
202
+ def clone_dataset_scenario(scenario_name, repo_url, model_repo_dir="./LWM", scenarios_dir="scenarios"):
203
+     current_dir = os.path.basename(os.getcwd())
204
+     if current_dir == "LWM":
205
+         model_repo_dir = "."
206
+
207
+     # Create the scenarios directory if it doesn't exist
208
+     scenarios_path = os.path.join(model_repo_dir, scenarios_dir)
209
+     if not os.path.exists(scenarios_path):
210
+         os.makedirs(scenarios_path)
211
+
212
+     scenario_path = os.path.join(scenarios_path, scenario_name)
213
+
214
+     # Initialize sparse checkout for the dataset repository
215
+     if not os.path.exists(os.path.join(scenarios_path, ".git")):
216
+         print(f"Initializing sparse checkout in {scenarios_path}...")
217
+         subprocess.run(["git", "clone", "--sparse", repo_url, "."], cwd=scenarios_path, check=True)
218
+         subprocess.run(["git", "sparse-checkout", "init", "--cone"], cwd=scenarios_path, check=True)
219
+         subprocess.run(["git", "lfs", "install"], cwd=scenarios_path, check=True)  # Install Git LFS if needed
220
+
221
+     # Add the requested scenario folder to sparse checkout
222
+     print(f"Adding {scenario_name} to sparse checkout...")
223
+     subprocess.run(["git", "sparse-checkout", "add", scenario_name], cwd=scenarios_path, check=True)
224
+
225
+     # Pull large files if needed (using Git LFS)
226
+     subprocess.run(["git", "lfs", "pull"], cwd=scenarios_path, check=True)
227
+
228
+     print(f"Successfully cloned {scenario_name} into {scenarios_path}.")
229
+
230
+ def clone_dataset_scenarios(selected_scenario_names, dataset_repo_url, model_repo_dir):
231
+     for scenario_name in selected_scenario_names:
232
+         clone_dataset_scenario(scenario_name, dataset_repo_url, model_repo_dir)
233
+ ```
234
+
235
+ ---
236
+
237
+ ### 5. **Clone the Model Repository**
238
+
239
+ Now, clone the **LWM-v1.1** model repository to your local system.
240
+
241
+ ```python
242
+ # Step 1: Clone the model repository (if not already cloned)
243
+ model_repo_url = "https://huggingface.co/wi-lab/lwm-v1.1"
244
+ model_repo_dir = "./LWM-v1.1"
245
+
246
+ if not os.path.exists(model_repo_dir):
247
+     print(f"Cloning model repository from {model_repo_url}...")
248
+     subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
249
+ ```
250
+
251
+ ---
252
+
253
+ ### 6. **Clone the Desired Dataset Scenarios**
254
+
255
+ You can now clone specific scenarios from the DeepMIMO dataset, as detailed in the table below:
256
+
257
+ 📊 **Dataset Overview**
258
+
259
+ | 📊 **Dataset** | 🏙️ **City** | 👥 **Number of Users** | 🔗 **DeepMIMO Page** |
260
+ |----------------|----------------------|------------------------|------------------------------------------------------------------------------------------------------------|
261
+ | Dataset 0 | 🌆 Denver | 1354 | [DeepMIMO City Scenario 18](https://www.deepmimo.net/scenarios/deepmimo-city-scenario18/) |
262
+ | Dataset 1 | 🏙️ Indianapolis | 3248 | [DeepMIMO City Scenario 15](https://www.deepmimo.net/scenarios/deepmimo-city-scenario15/) |
263
+ | Dataset 2 | 🌇 Oklahoma | 3455 | [DeepMIMO City Scenario 19](https://www.deepmimo.net/scenarios/deepmimo-city-scenario19/) |
264
+ | Dataset 3 | 🌆 Fort Worth | 1902 | [DeepMIMO City Scenario 12](https://www.deepmimo.net/scenarios/deepmimo-city-scenario12/) |
265
+ | Dataset 4 | 🌉 Santa Clara | 2689 | [DeepMIMO City Scenario 11](https://www.deepmimo.net/scenarios/deepmimo-city-scenario11/) |
266
+ | Dataset 5 | 🌅 San Diego | 2192 | [DeepMIMO City Scenario 7](https://www.deepmimo.net/scenarios/deepmimo-city-scenario7/) |
267
+
268
+ It is important to note that these six datasets were **not** used during the pre-training of the LWM model, and the high-quality embeddings produced are a testament to LWM’s robust generalization capabilities rather than overfitting.
269
+
270
+ If you plan to use custom datasets, please ensure that your complex channel contains at most **8196 elements** (N * SC <= 8196). In **LWM-v1.0**, the input was restricted to complex channels of size (N, SC) = (32, 32). However, with **LWM-v1.1**, you can now feed complex channels of arbitrary sizes, providing greater flexibility for your specific use case! 😊
271
+
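A quick sanity check along these lines (illustrative; the channel shape below is hypothetical):

```python
import numpy as np

# LWM-v1.1 accepts arbitrary (N, SC) shapes as long as N * SC <= 8196.
custom_channel = np.random.randn(64, 120) + 1j * np.random.randn(64, 120)  # hypothetical (N, SC)
N, SC = custom_channel.shape
assert N * SC <= 8196, f"Channel too large: N*SC = {N * SC} exceeds 8196"
print(f"OK: N={N}, SC={SC}, N*SC={N * SC}")
```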
272
+ #### **Clone the Scenarios:**
273
+ ```python
274
+ import numpy as np
275
+ dataset_repo_url = "https://huggingface.co/datasets/wi-lab/lwm" # Base URL for dataset repo
276
+ scenario_names = np.array([
277
+ "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
278
+ "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
279
+ ])
280
+
281
+ scenario_idxs = np.array([3]) # Select the scenario index
282
+ selected_scenario_names = scenario_names[scenario_idxs]
283
+
284
+ # Clone the requested scenarios
285
+ clone_dataset_scenarios(selected_scenario_names, dataset_repo_url, model_repo_dir)
286
+ ```
287
+
288
+ ---
289
+
290
+ ### 7. **Change the Working Directory to LWM**
291
+
292
+ Before proceeding, ensure you are in the correct working directory for the **LWM** repository:
293
+
294
+ ```python
295
+ import os
296
+
297
+ if os.path.exists(model_repo_dir):
298
+     os.chdir(model_repo_dir)
299
+     print(f"Changed working directory to {os.getcwd()}")
300
+ else:
301
+     print(f"Directory {model_repo_dir} does not exist. Please check if the repository is cloned properly.")
302
+ ```
303
+
304
+ This ensures that all paths and dependencies align with the repository structure.
305
+
306
+ ---
307
+
308
+ ## **Downstream Tasks**
309
+
310
+ ### **Loading Required Packages and Modules**
311
+
312
+ To set up your environment for downstream tasks, import the necessary modules and suppress unnecessary warnings:
313
+
314
+ ```python
315
+ from input_preprocess import tokenizer, scenarios_list
316
+ from inference import lwm_inference
317
+ from utils import prepare_loaders
318
+ from train import finetune
319
+ import lwm_model
320
+ import matplotlib.pyplot as plt
321
+ import numpy as np
322
+ import torch
323
+ import torch.nn as nn
324
+ import warnings
325
+
326
+ warnings.filterwarnings("ignore", category=UserWarning)
327
+ ```
328
+
329
+ ### **Setting Parameters for Downstream Tasks**
330
+
331
+ Define the parameters for your downstream task. This includes selecting the desired task, visualization method, and data input types. Additionally, you can either use default tasks or manually define labels for custom tasks. If your primary goal is to extract **LWM embeddings**, you can skip task definitions and labels.
332
+
333
+ ```python
334
+ n_beams = 16
335
+ task = ['Beam Prediction', 'LoS/NLoS Classification'][1] # Default: LoS/NLoS Classification
336
+ task_type = ["classification", "regression"][0] # Default: Classification
337
+ visualization_method = ["pca", "umap", "tsne"][2] # Default: TSNE
338
+ input_types = ["cls_emb", "channel_emb", "raw"] # Supported input types
339
+ train_ratios = [.001, .01, .05, .1, .25, .5, .8] # Fraction of data for training
340
+ fine_tuning_status = [None, ["layers.8", "layers.9", "layers.10", "layers.11"], "full"] # Fine-tuning configurations
341
+ selected_scenario_names = [scenarios_list()[18]] # Choose a specific scenario
342
+
343
+ preprocessed_data, labels, raw_chs = tokenizer(
344
+ selected_scenario_names,
345
+ bs_idxs=[3],
346
+ load_data=False,
347
+ task=task,
348
+ n_beams=n_beams
349
+ )
350
+ ```
351
+
352
+ #### **Parameters**
353
+
354
+ 1. **`n_beams`**:
355
+ - Specifies the number of beams in the codebook for the **Beam Prediction** task.
356
+ - For example, `16` beams indicate 16 possible output classes when predicting the optimal beam index.
357
+
358
+ 2. **`task`**:
359
+ - Defines the downstream task to perform:
360
+ - `'Beam Prediction'`: Predicts the optimal beam index from sub-6GHz channels for mmWave communications.
361
+ - `'LoS/NLoS Classification'`: Classifies channels into **Line-of-Sight (LoS)** or **Non-Line-of-Sight (NLoS)**.
362
+ - Here, **LoS/NLoS Classification** is selected (`[1]`).
363
+
364
+ 3. **`task_type`**:
365
+ - Specifies whether the task involves **classification** (discrete outputs) or **regression** (continuous outputs).
366
+ - In this case, the task is a **classification problem** (`[0]`).
367
+
368
+ 4. **`visualization_method`**:
369
+ - Determines how the channel embeddings will be visualized during evaluation:
370
+ - `"pca"`: Principal Component Analysis for linear dimensionality reduction.
371
+ - `"umap"`: Uniform Manifold Approximation and Projection for capturing non-linear structures.
372
+ - `"tsne"`: t-distributed Stochastic Neighbor Embedding, ideal for clustering visualization.
373
+ - Here, **t-SNE** is used (`[2]`).
374
+
375
+ 5. **`input_types`**:
376
+ - Lists the types of inputs supported by the model:
377
+ - `"cls_emb"`: CLS token embeddings of size (n_samples, 128) representing holistic channel features.
378
+ - `"channel_emb"`: Lower-level embeddings of szie (n_samples, n_patches, 128) derived from channel patches.
379
+ - `"raw"`: Raw wireless channel data without preprocessing.
380
+ - These input types enable flexibility in evaluating and fine-tuning the model.
381
+
382
+ 6. **`train_ratios`**:
383
+ - Specifies the fraction of the dataset used for training:
384
+ - Values like `0.001` (0.1%) simulate data-limited scenarios, while `0.8` (80%) allows training with most of the dataset.
385
+ - This parameter is particularly useful for analyzing model performance under varying levels of labeled data availability. LWM embeddings show their largest gains over raw channel representations in data-limited scenarios.
386
+
387
+ 7. **`fine_tuning_status`**:
388
+ - Determines how the pretrained **LWM-v1.1** model will be fine-tuned:
389
+ - `None`: Uses the pretrained model as-is, without fine-tuning.
390
+ - `["layers.8", "layers.9", "layers.10", "layers.11"]`: Fine-tunes only the last four encoder layers, suitable for task-specific adaptation. The set of desired layers can be selected ("layers.0" to "layers.11)".
391
+ - `"full"`: Fine-tunes the entire model, ideal for significant task adaptation.
392
+ - These configurations help balance performance improvements with computational efficiency.
393
+
394
+ 8. **`selected_scenario_names`**:
395
+ - Specifies the scenario(s) from the dataset to use for training and evaluation.
396
+ - **`scenarios_list()`**: A utility function that provides all available scenarios in the dataset.
397
+ - `[18]`: Selects the scenario at index 18, which corresponds to a specific wireless environment and base station configuration. In this case, it represents channels of size (128, 32) between BS 3 and users in the densified Denver scenario.
398
+
399
+ ---
400
+
401
+ #### **Preprocessing**
402
+
403
+ The `tokenizer` function processes the raw wireless channel data based on the selected parameters:
404
+
405
+ ```python
406
+ preprocessed_data, labels, raw_chs = tokenizer(
407
+ selected_scenario_names,
408
+ bs_idxs=[3],
409
+ load_data=False,
410
+ task=task,
411
+ n_beams=n_beams
412
+ )
413
+ ```
414
+
415
+ 1. **`selected_scenario_names`**: Defines the scenario(s) to tokenize.
416
+ 2. **`bs_idxs`**: Specifies the base station(s) to include in the scenario.
417
+ - `[3]`: Includes only the 3rd base station.
418
+ 3. **`load_data`**:
419
+ - `False`: Specifies that the function should generate the densified DeepMIMO scenario and save it. If the scenario has already been pre-saved, set this parameter to `True`.
420
+ 4. **`task`**: Sets the downstream task (e.g., Beam Prediction or LoS/NLoS Classification).
421
+ 5. **`n_beams`**: Specifies the number of beams for **Beam Prediction** tasks.
422
+
423
+ **Outputs**:
424
+ - **`preprocessed_data`**: Tokenized wireless channel data, formatted for the model.
425
+ - **`labels`**: Labels corresponding to the task (e.g., beam indexes or LoS/NLoS categories).
426
+ - **`raw_chs`**: Original raw wireless channel data for comparison or visualization.
427
+
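If you want a quick sanity check before moving on, you can inspect the returned objects (the exact shapes depend on the selected scenario and task):

```python
print(type(preprocessed_data), type(labels), type(raw_chs))
print("number of labeled samples:", len(labels))
```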
428
+ ---
429
+
430
+ ### **Loading the Pretrained LWM-v1.1 Model**
431
+
432
+ Load the **LWM-v1.1** pretrained model and prepare it for downstream tasks. The model is initialized on the specified GPU(s) or CPU if no GPU is available.
433
+
434
+ ```python
435
+ from lwm_model import lwm # Adjust the import path as needed
436
+
437
+ gpu_ids = [0]
438
+ device = torch.device(f"cuda:{gpu_ids[0]}" if torch.cuda.is_available() else "cpu")
439
+
440
+ # Initialize the model
441
+ model = lwm().to(device)
442
+
443
+ # Load the pretrained model state
444
+ model_name = "model.pth"
445
+ state_dict_path = f"models/{model_name}"
446
+ state_dict = torch.load(state_dict_path, map_location=device)
447
+
448
+ # Clean state dictionary for DataParallel compatibility
449
+ clean_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
450
+ model.load_state_dict(clean_state_dict)
451
+
452
+ # Use multiple GPUs if specified
453
+ if len(gpu_ids) > 1:
454
+     model = nn.DataParallel(model, device_ids=gpu_ids)
455
+
456
+ print(f"Model loaded successfully on device: {device}")
457
+ ```
458
+
459
+ ---
460
+
461
+ ### **Visualizing the Original Channel and Embedding Spaces**
462
+
463
+ If you wish to visualize how the original channel space and the embedding space align with the task labels before fine-tuning, or if you simply want to generate embeddings with the pretrained model, run:
464
+
465
+ ```python
466
+ chs = lwm_inference(
467
+ model,
468
+ preprocessed_data,
469
+ input_type="cls_emb",
470
+ device=device,
471
+ batch_size=64,
472
+ visualization=True,
473
+ labels=labels,
474
+ visualization_method=visualization_method
475
+ )
476
+ ```
477
+
478
+ This generates embeddings or visualizations, depending on your configuration.
479
+
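The same helper can also return patch-level embeddings by switching the input type (a sketch mirroring the call above; `visualization=False` simply skips the plot):

```python
channel_embs = lwm_inference(
    model,
    preprocessed_data,
    input_type="channel_emb",
    device=device,
    batch_size=64,
    visualization=False,
    labels=labels,
    visualization_method=visualization_method
)
```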
480
+ ---
481
+
482
+ ### **Fine-Tuning the Pretrained Model**
483
+
484
+ Fine-tune the **LWM-v1.1** model for your specific downstream task. You can choose to leave the pretrained model unchanged, fine-tune specific encoder layers, or fine-tune the entire model. Avoid over-parameterizing the downstream model to maintain generalization.
485
+
486
+ ```python
487
+ results = np.zeros((len(fine_tuning_status), len(input_types), len(train_ratios)))
488
+
489
+ for fine_tuning_stat_idx, fine_tuning_stat in enumerate(fine_tuning_status):
490
+     for input_type_idx, input_type in enumerate(input_types):
491
+
492
+         if input_type == "raw" and fine_tuning_stat is not None:
493
+             continue
494
+
495
+         selected_patches_idxs = None
496
+         for train_ratio_idx, train_ratio in enumerate(train_ratios):
497
+             print(f"\nfine-tuning status: {fine_tuning_stat}")
498
+             print(f"input type: {input_type}")
499
+             print(f"train ratio: {train_ratio}\n")
500
+
501
+             # Prepare data loaders
502
+             train_loader, val_loader, samples, target = prepare_loaders(
503
+                 preprocessed_data=preprocessed_data,
504
+                 labels=labels,
505
+                 selected_patches_idxs=selected_patches_idxs,
506
+                 input_type=input_type,
507
+                 task_type=task_type,
508
+                 train_ratio=train_ratio,
509
+                 batch_size=128,
510
+                 seed=42
511
+             )
512
+
513
+             # Fine-tune LWM
514
+             fine_tuned_model, best_model_path, train_losses, val_losses, f1_scores, attn_maps_ft = finetune(
515
+                 base_model=model,
516
+                 train_loader=train_loader,
517
+                 val_loader=val_loader,
518
+                 task_type=task_type,
519
+                 input_type=input_type,
520
+                 num_classes=n_beams if task == 'Beam Prediction' else 2 if task == 'LoS/NLoS Classification' else None,
521
+                 output_dim=target.shape[-1] if task_type == 'regression' else None,
522
+                 use_custom_head=True,
523
+                 fine_tune_layers=fine_tuning_stat,
524
+                 optimizer_config={"lr": 1e-3},
525
+                 epochs=15,
526
+                 device=device,
527
+                 task=task
528
+             )
529
+
530
+             results[fine_tuning_stat_idx][input_type_idx][train_ratio_idx] = f1_scores[-1]
531
+ ```
532
+
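Because the sweep above can take a while, you may want to persist the `results` array before plotting (the filename is arbitrary):

```python
np.save("downstream_results.npy", results)  # reload later with np.load("downstream_results.npy")
```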
533
+ ---
534
+
535
+ ### **Visualizing Fine-Tuning Results**
536
+
537
+ Visualize the effect of fine-tuning on performance across different training ratios, input types, and fine-tuning configurations:
538
+
539
+ ```python
540
+ markers = ['o', 's', 'D']
541
+ input_labels = ['CLS Emb', 'CHS Emb', 'Raw']  # renamed so the task `labels` used later are not overwritten
542
+ fine_tuning_status_labels = ['No FT', 'Partial FT', 'Full FT']
543
+ line_styles = ['-', '--', ':']
544
+ colors = plt.cm.viridis(np.linspace(0, 0.8, len(input_labels)))
545
+
546
+ plt.figure(figsize=(12, 8), dpi=500)
547
+ for ft_idx, (ft_status_label, line_style) in enumerate(zip(fine_tuning_status_labels, line_styles)):
548
+     for idx, (marker, label, color) in enumerate(zip(markers, input_labels, colors)):
549
+         if label == "Raw" and ft_status_label != "No FT":
550
+             continue
551
+         plt.plot(
552
+             train_ratios,
553
+             results[ft_idx, idx],
554
+             marker=marker,
555
+             linestyle=line_style,
556
+             label=f"{label} ({ft_status_label})",
557
+             color=color,
558
+             linewidth=3,
559
+             markersize=9
560
+         )
561
+ plt.xscale('log')
562
+ plt.xlabel("Train Ratio", fontsize=20)
563
+ plt.ylabel("F1-Score", fontsize=20)
564
+ plt.legend(fontsize=17, loc="best")
565
+ plt.grid(True, linestyle="--", alpha=0.7)
566
+ plt.xticks(fontsize=17)
567
+ plt.yticks(fontsize=17)
568
+ plt.tight_layout()
569
+ plt.show()
570
+ ```
571
+
572
+ ---
573
+
574
+ ### **Comparing the Original Channel Space with Fine-Tuned Embedding Space**
575
+
576
+ After fine-tuning, compare how the embedding space has adapted to task-specific details:
577
+
578
+ ```python
579
+ chs = lwm_inference(
580
+ fine_tuned_model.model,
581
+ preprocessed_data,
582
+ input_type="cls_emb",
583
+ device=device,
584
+ batch_size=64,
585
+ visualization=False,
586
+ labels=labels,
587
+ visualization_method=visualization_method
588
+ )
589
+ ```
590
+
591
+ ---
592
+
593
+ ### **Explore the Interactive Demo**
594
+
595
+ Experience **LWM** interactively via our Hugging Face Spaces demo:
596
+ [**Try the Interactive Demo!**](https://huggingface.co/spaces/wi-lab/lwm-interactive-demo)
597
+
598
+ ---
599
+
600
+ You are now ready to explore the power of **LWM** in wireless communications! Start processing datasets and generate high-quality embeddings to advance your research or applications.
601
+
602
+ If you have questions or need assistance, feel free to:
603
+ - Visit the [Hugging Face Discussions](https://huggingface.co/wi-lab/lwm/discussions) for community support.
604
+ - Check out the [LWM website FAQ](https://lwm-wireless.net/community).
605
+ - Contact us directly via email at [[email protected]](mailto:[email protected]).