Update README.md
README.md (CHANGED)
@@ -114,7 +114,7 @@

### **Try It Now!**

Explore **LWM-v1.1** on Hugging Face with preloaded datasets, fine-tuning options, and pretrained models to kickstart your projects.

[👉 Access the model here!](https://huggingface.co/wi-lab/lwm-v1.1/tree/main)

---

@@ -315,7 +315,16 @@

---

Next, we proceed in two distinct directions, each covering a critical aspect of **LWM-v1.1**:

1. **INFERENCE & DOWNSTREAM TASKS**: Use the pre-trained LWM-v1.1 model for inference and adapt it to specific downstream tasks such as classification or regression.
2. **PRE-TRAINING LWM-v1.1**: Pre-train the model from scratch, including the techniques and datasets used to develop its foundational capabilities.

The corresponding scripts for these processes are **`downstream.py`** and **`main.py`**, available in the [**Hugging Face repository**](https://huggingface.co/wi-lab/lwm-v1.1/tree/main).

---

## **1. INFERENCE & DOWNSTREAM TASKS**

### **Loading Required Packages and Modules**

@@ -493,8 +502,6 @@

|:---------------------------------------------:|:---------------------------------------------:|:---------------------------------------------:|
| **Raw Channels** | **General-purpose Embeddings** | **Task-specific Embeddings** |

### **Beam Prediction Task**

|  |  |  |

@@ -615,7 +622,233 @@

---

## **2. PRE-TRAINING LWM-v1.1**

This section details the pre-training of **LWM-v1.1**, covering data preparation, model initialization, and optimization settings. Each step is designed to help the model learn robust, general-purpose embeddings for wireless channel data.

---

### **Loading Required Packages and Modules**

The following packages are required to preprocess the data, initialize the model, and train it:

```python
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import random_split

import lwm_model
from input_preprocess import tokenizer, scenarios_list
from train import train_lwm
from utils import create_dataloader, count_parameters

warnings.filterwarnings("ignore", category=UserWarning)
```

---

### **Settings**

Set the key hyperparameters for pretraining:

```python
# Training
EPOCHS = 50
BATCH_SIZE = 128
VAL_BATCH_SIZE = 64
WARMUP_EPOCHS = 5
BASE_LR = 5e-4
WEIGHT_DECAY = 0.05
BETA1 = 0.9
BETA2 = 0.999
MASK_PERCENT = 0.40

# Data
N_ROWS = 4
N_COLUMNS = 4
ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2  # real + imaginary parts per patch
MAX_LEN = 513

# Model
D_MODEL = 128
N_LAYERS = 12
N_HEADS = 8
DROPOUT = 0.1
```

- **Data Parameters**:
  - **`N_ROWS` and `N_COLUMNS`**: Rows and columns in each channel patch (4 antennas × 4 subcarriers).
  - **`ELEMENT_LENGTH`**: Elements per patch, counting real and imaginary parts (4 × 4 × 2 = 32).
  - **`MAX_LEN`**: Maximum input length (including positional encoding).

- **Model Hyperparameters**:
  - **`D_MODEL`**: Embedding size (128).
  - **`N_LAYERS`**: Number of transformer layers (12).
  - **`N_HEADS`**: Number of attention heads (8).
  - **`DROPOUT`**: Dropout probability (0.1).

- **Training Hyperparameters**:
  - **`EPOCHS`**: Total number of epochs (50).
  - **`BATCH_SIZE` and `VAL_BATCH_SIZE`**: Batch sizes for training (128) and validation (64).
  - **`BASE_LR` and `WARMUP_EPOCHS`**: Initial learning rate (5e-4) and warmup period (5 epochs).
  - **`MASK_PERCENT`**: Fraction of patches masked during pretraining (40%); the sketch below illustrates what this means at the patch level.
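
The following is an illustration only, not the repository's masking code (which lives in `input_preprocess.py`); `mask_patch_indices` is a hypothetical helper that shows how a 40% subset of patch positions can be drawn for masking:

```python
import numpy as np

def mask_patch_indices(n_patches: int, mask_percent: float, rng: np.random.Generator):
    """Illustrative only: choose which patch positions to mask."""
    n_masked = int(mask_percent * n_patches)
    return rng.permutation(n_patches)[:n_masked]

rng = np.random.default_rng(42)
masked = mask_patch_indices(n_patches=128, mask_percent=0.40, rng=rng)
print(f"{len(masked)} of 128 patches masked")  # 51 of 128 patches masked
```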

---

### **Generating the Dataset**

The dataset is prepared by tokenizing scenarios with the `tokenizer` function:

```python
bs_idxs = [1, 2, 3]                              # base stations 1-3 for data generation
selected_scenario_names = scenarios_list()[:80]  # first 80 scenarios

preprocessed_data = tokenizer(
    selected_scenario_names,
    MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,
    seed=42
)
```

- **Parameters**:
  - **`bs_idxs`**: Selects base stations 1, 2, and 3 for data generation.
  - **`selected_scenario_names`**: The first 80 scenarios from `scenarios_list()`.
  - **`masking_percent`**: Masks 40% of the patches in each channel during pretraining.

- **Output**:
  - **`preprocessed_data`**: A dictionary mapping each scenario name to its preprocessed samples; a quick sanity check is shown below.
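
Assuming, as the split code in the next step does, that each value is an indexable collection of samples, you can inspect the output like this:

```python
for name, samples in list(preprocessed_data.items())[:3]:
    print(f"{name}: {len(samples)} samples")
```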

---

### **Splitting the Dataset**

Split each scenario's samples into training, validation, and test sets:

```python
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

train_ratio = 0.8
val_ratio = 0.2

train_data, val_data, test_data = {}, {}, {}

for key, samples in preprocessed_data.items():
    total_samples = len(samples)
    train_size = int(train_ratio * total_samples)
    val_size = int(val_ratio * total_samples)
    test_size = total_samples - train_size - val_size  # integer-rounding remainder

    train_data[key], val_data[key], test_data[key] = random_split(
        samples, [train_size, val_size, test_size]
    )

train_loaders = create_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loaders = create_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
```

- **Data Ratios**:
  - **`train_ratio`**: 80% of each scenario's samples for training.
  - **`val_ratio`**: 20% for validation.
  - With `train_ratio + val_ratio = 1.0`, only the integer-rounding remainder is left for testing: a scenario with 1,000 samples yields 800 training, 200 validation, and 0 test samples. The snippet below makes this concrete.

- **Data Loaders**:
  - `train_loaders` and `val_loaders` provide batched data for training and validation.
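
To check the split sizes for one scenario:

```python
key = next(iter(preprocessed_data))
print(f"{key}: train={len(train_data[key])}, "
      f"val={len(val_data[key])}, test={len(test_data[key])}")
```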

---

### **Initializing the Model**

Initialize **LWM-v1.1** and optionally load a pretrained checkpoint:

```python
load_model = True
gpu_ids = [0]
device = torch.device("cuda:0")  # assumes a CUDA-capable GPU
model = lwm_model.lwm().to(device)

if load_model:
    model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
    state_dict = torch.load(f"models/{model_name}", map_location=device)
    # Strip the "module." prefix left by DataParallel when the checkpoint was saved
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)

model = nn.DataParallel(model, gpu_ids)
print(f"Model loaded successfully on GPU {device.index}")

n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
```

- **GPU Handling**:
  - The model runs on GPU `cuda:0`; adding more device indices to `gpu_ids` enables multi-GPU training through `nn.DataParallel`.

- **Checkpoint Loading**:
  - If `load_model` is `True`, a pretrained checkpoint is loaded so training starts from learned weights rather than from scratch.

- **Parameter Count**:
  - `count_parameters` reports the number of trainable parameters for transparency; a minimal equivalent is sketched below.
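
`count_parameters` comes from the repository's `utils` module; assuming it counts trainable parameters only, a minimal equivalent would be:

```python
def count_parameters(model: nn.Module) -> int:
    # Sum the element counts of all parameters that receive gradients
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```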

---

### **Optimizer and Learning Rate Scheduler**

Define the optimizer and learning rate scheduler:

```python
optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY
)

# lr_lambda needs these three constants. The scheduler is assumed to step once
# per epoch here; if train_lwm steps it once per batch instead, scale both step
# counts by the number of batches per epoch. MIN_LR is the assumed floor that
# the cosine decay approaches.
WARMUP_STEPS = WARMUP_EPOCHS
TOTAL_STEPS = EPOCHS
MIN_LR = 1e-6

def lr_lambda(current_step):
    if current_step < WARMUP_STEPS:
        # Linear warmup from 0 toward BASE_LR
        return current_step / WARMUP_STEPS
    # Cosine decay from BASE_LR down to MIN_LR
    scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
    return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
```

- **AdamW Optimizer**:
  - Decouples weight decay from the gradient update for better generalization.
- **Learning Rate Scheduler**:
  - Combines linear warmup with cosine decay for smooth training; a quick way to inspect the resulting schedule is shown below.
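
To sanity-check the schedule before committing to a long run, you can step a throwaway copy of it and print the learning rate (a quick check, not part of the original script):

```python
probe_opt = AdamW([torch.zeros(1, requires_grad=True)], lr=BASE_LR)
probe_sched = LambdaLR(probe_opt, lr_lambda=lr_lambda)

for epoch in range(EPOCHS):
    lr = probe_opt.param_groups[0]["lr"]
    if epoch <= WARMUP_EPOCHS or epoch == EPOCHS - 1:
        print(f"epoch {epoch:2d}: lr = {lr:.2e}")
    probe_opt.step()    # dummy optimizer step so the scheduler advances cleanly
    probe_sched.step()
```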

---

### **Training the Model**

Train the model using the `train_lwm` function:

```python
pretrained_model = train_lwm(
    model,
    train_loaders,
    val_loaders,
    optimizer,
    scheduler,
    EPOCHS,
    device=device
)
```

- **Inputs**:
  - **`model`**: The initialized LWM model.
  - **`train_loaders` and `val_loaders`**: Data loaders for training and validation.
  - **`optimizer` and `scheduler`**: The configured optimizer and learning rate scheduler.
  - **`EPOCHS`**: Number of training epochs.
  - **`device`**: Whether training runs on GPU or CPU.

- **Output**:
  - **`pretrained_model`**: The trained LWM-v1.1 model; a sketch for persisting it follows.
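
After training completes, you will typically want to save the weights. A minimal sketch (the output path is hypothetical; the repository's own checkpoints follow the naming pattern of the file loaded earlier):

```python
save_path = "models/lwm_v1.1_pretrained.pth"  # hypothetical output path
torch.save(pretrained_model.state_dict(), save_path)
print(f"Saved pretrained weights to {save_path}")
```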

---

### **Explore the Interactive Demo**

Experience **LWM** interactively via our Hugging Face Spaces demo:

[**Try the Interactive Demo!**](https://huggingface.co/spaces/wi-lab/lwm-interactive-demo)