wi-lab committed
Commit 93d4703 · verified · 1 Parent(s): 4bfda2c

Update README.md

Files changed (1):
  1. README.md +238 -5

README.md CHANGED
@@ -114,7 +114,7 @@ The updates in LWM-v1.1 were driven by real-world demands for greater flexibility
 
 ### **Try It Now!**
 Explore **LWM-v1.1** on Hugging Face with preloaded datasets, fine-tuning options, and pretrained models to kickstart your projects.
- [👉 Access the model here!](https://huggingface.co/wi-lab/lwm-v1.1)
+ [👉 Access the model here!](https://huggingface.co/wi-lab/lwm-v1.1/tree/main)
 
 ---
 
@@ -315,7 +315,16 @@ This ensures that all paths and dependencies align with the repository structure
 
 ---
 
- ## **Downstream Tasks**
+ Next, we proceed in two distinct directions, each covering a critical aspect of **LWM-v1.1**:
+
+ 1. **INFERENCE & DOWNSTREAM TASKS**: Use the pre-trained LWM-v1.1 model to perform inference and adapt it to specific tasks such as classification or regression.
+ 2. **PRE-TRAINING LWM-v1.1**: Pre-train the model from scratch, including the techniques and datasets used to develop its foundational capabilities.
+
+ The corresponding scripts, **`downstream.py`** and **`main.py`**, are available in the [**Hugging Face repository**](https://huggingface.co/wi-lab/lwm-v1.1/tree/main).
+
+ ---
+
+ ## **1. INFERENCE & DOWNSTREAM TASKS**
 
 ### **Loading Required Packages and Modules**
 
@@ -493,8 +502,6 @@ This generates embeddings or visualizations, depending on your configuration.
 |:---------------------------------------------:|:---------------------------------------------:|:---------------------------------------------:|
 | **Raw Channels** | **General-purpose Embeddings** | **Task-specific Embeddings** |
 
- ---
-
 ### **Beam Prediction Task**
 
 | ![Image 4](https://huggingface.co/wi-lab/lwm-v1.1/resolve/main/images/bp_raw.png) | ![Image 5](https://huggingface.co/wi-lab/lwm-v1.1/resolve/main/images/bp_embedding_noFT.png) | ![Image 6](https://huggingface.co/wi-lab/lwm-v1.1/resolve/main/images/bp_embedding_FT.png) |
@@ -615,7 +622,233 @@ chs = lwm_inference(
 
 ---
 
- ### **12. Explore the Interactive Demo**
+ ## **2. PRE-TRAINING LWM-v1.1**
+
+ This section details how **LWM-v1.1** is pre-trained, covering data preparation, model initialization, and optimization settings. Each step is designed so the model learns robust, general-purpose embeddings for wireless channel data.
+
+ ---
+
+ ### **Loading Required Packages and Modules**
+
+ The following packages are required to preprocess the data, initialize the model, and train it:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import random_split
+ from input_preprocess import tokenizer, scenarios_list
+ from utils import create_dataloader, count_parameters
+ import numpy as np
+ import lwm_model
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+ from torch.optim import AdamW
+ from train import train_lwm
+ import warnings
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+ ```
+
+ ---
+
+ ### **Settings**
+
+ Set the key hyperparameters for pretraining:
+
+ ```python
+ EPOCHS = 50
+ BATCH_SIZE = 128
+ VAL_BATCH_SIZE = 64
+ WARMUP_EPOCHS = 5
+ BASE_LR = 5e-4
+ N_ROWS = 4
+ N_COLUMNS = 4
+ ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2
+ D_MODEL = 128
+ MAX_LEN = 513
+ N_LAYERS = 12
+ WEIGHT_DECAY = 0.05
+ BETA1 = 0.9
+ BETA2 = 0.999
+ MASK_PERCENT = 0.40
+ N_HEADS = 8
+ DROPOUT = 0.1
+ ```
+
+ - **Data Parameters** (sanity-checked in the sketch below):
+   - **`N_ROWS` and `N_COLUMNS`**: Number of rows and columns in each channel patch (4 antennas × 4 subcarriers).
+   - **`ELEMENT_LENGTH`**: Number of real values per patch, counting real and imaginary parts (\(4 \times 4 \times 2 = 32\)).
+   - **`MAX_LEN`**: Maximum input sequence length (513), which also sets the size of the positional-encoding table.
+
+ - **Model Hyperparameters**:
+   - **`D_MODEL`**: Embedding size (128).
+   - **`N_LAYERS`**: Number of transformer layers (12).
+   - **`N_HEADS`**: Number of attention heads (8).
+   - **`DROPOUT`**: Dropout probability (0.1).
+
+ - **Training Hyperparameters**:
+   - **`EPOCHS`**: Total number of epochs (50).
+   - **`BATCH_SIZE` and `VAL_BATCH_SIZE`**: Batch sizes for training (128) and validation (64).
+   - **`BASE_LR` and `WARMUP_EPOCHS`**: Initial learning rate (5e-4) and warmup period (5 epochs).
+   - **`MASK_PERCENT`**: Fraction of patches masked during pretraining (40%).
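+
+ As a quick sanity check on these dimensions, the sketch below recomputes the patch arithmetic for a hypothetical channel size (the 32 × 32 channel and the CLS-style first token are assumptions for illustration; the actual patching happens inside `tokenizer`):
+
+ ```python
+ # Illustrative only: recompute the patch arithmetic from the settings above.
+ n_ant, n_sub = 32, 32                                  # assumed example: 32 antennas x 32 subcarriers
+ patch_elems = N_ROWS * N_COLUMNS * 2                   # 4 x 4 complex entries -> 32 real values (= ELEMENT_LENGTH)
+ n_patches = (n_ant // N_ROWS) * (n_sub // N_COLUMNS)   # 8 * 8 = 64 patches
+ seq_len = n_patches + 1                                # +1 assuming a CLS-style token is prepended
+ assert seq_len <= MAX_LEN                              # 65 <= 513, so this channel fits
+ print(patch_elems, n_patches, seq_len)                 # 32 64 65
+ ```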
+
+ ---
+
+ ### **Generating the Dataset**
+
+ The dataset is prepared by tokenizing scenarios using the `tokenizer` function:
+
+ ```python
+ bs_idxs = [1, 2, 3]
+ selected_scenario_names = scenarios_list()[:80]
+ preprocessed_data = tokenizer(
+     selected_scenario_names,
+     MAX_LEN,
+     masking_percent=MASK_PERCENT,
+     mask=True,
+     seed=42
+ )
+ ```
+
+ - **Parameters**:
+   - **`bs_idxs`**: Selects base stations 1, 2, and 3 for data generation.
+   - **`selected_scenario_names`**: The first 80 scenarios returned by `scenarios_list()`.
+   - **`masking_percent`**: Masks 40% of the patches in each channel during pretraining.
+
+ - **Outputs**:
+   - **`preprocessed_data`**: A dictionary mapping scenario names to their preprocessed samples (inspected in the sketch below).
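+
+ A minimal way to confirm what the tokenizer produced (assuming the dictionary structure described above):
+
+ ```python
+ # Illustrative check: one entry per scenario, each holding its tokenized samples.
+ for scenario, samples in preprocessed_data.items():
+     print(f"{scenario}: {len(samples)} samples")
+ ```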
+
+ ---
+
+ ### **Splitting the Dataset**
+
+ Split the dataset into training, validation, and test sets:
+
+ ```python
+ SEED = 42
+ torch.manual_seed(SEED)
+ np.random.seed(SEED)
+ train_ratio = 0.8
+ val_ratio = 0.2   # note: 0.8 + 0.2 = 1.0, so the test split only gets rounding leftovers
+ train_data = {}
+ val_data = {}
+ test_data = {}
+
+ for key, samples in preprocessed_data.items():
+     total_samples = len(samples)
+     train_size = int(train_ratio * total_samples)
+     val_size = int(val_ratio * total_samples)
+     test_size = total_samples - train_size - val_size
+
+     train_data[key], val_data[key], test_data[key] = random_split(
+         samples, [train_size, val_size, test_size]
+     )
+
+ train_loaders = create_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
+ val_loaders = create_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
+ ```
+
+ - **Data Ratios**:
+   - **`train_ratio`**: 80% of the data for training.
+   - **`val_ratio`**: 20% for validation.
+   - Because the two ratios already sum to 1.0, the test split receives only integer-rounding leftovers; shrink the ratios (e.g., 0.8/0.1) if you need a real test set. The sketch below verifies the resulting sizes.
+
+ - **Data Loaders**:
+   - `train_loaders` and `val_loaders` provide batched data for training and validation.
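+
+ A small sanity check (illustrative; `random_split` returns `Subset` objects, so `len()` works directly):
+
+ ```python
+ # Confirm the per-scenario split sizes add up to the original sample counts.
+ for key, samples in preprocessed_data.items():
+     n_train, n_val, n_test = len(train_data[key]), len(val_data[key]), len(test_data[key])
+     assert n_train + n_val + n_test == len(samples)
+     print(f"{key}: train={n_train}, val={n_val}, test={n_test}")
+ ```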
+
+ ---
+
+ ### **Initializing the Model**
+
+ Initialize **LWM-v1.1** and optionally load a pretrained checkpoint:
+
+ ```python
+ load_model = True
+ gpu_ids = [0]
+ device = torch.device("cuda:0")
+ model = lwm_model.lwm().to(device)
+
+ if load_model:
+     model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
+     state_dict = torch.load(f"models/{model_name}", map_location=device)
+     # Strip the "module." prefix left behind by DataParallel checkpoints.
+     new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+     model.load_state_dict(new_state_dict)
+
+ model = nn.DataParallel(model, gpu_ids)
+ print(f"Model loaded successfully on GPU {device.index}")
+ n_parameters = count_parameters(model)
+ print(f"Number of trainable parameters: {n_parameters:,}")
+ ```
+
+ - **GPU Handling**:
+   - The model runs on GPU `cuda:0`; add more IDs to `gpu_ids` to train on multiple GPUs via `nn.DataParallel`. (A portable device-selection variant follows below.)
+
+ - **Checkpoint Loading**:
+   - If `load_model` is `True`, a pretrained checkpoint is loaded so the model starts from learned weights rather than random initialization.
+
+ - **Parameter Count**:
+   - Prints the number of trainable parameters for transparency.
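+
+ If you also want the script to run on machines without a GPU, a common variant (an assumption on our part; the snippet above requires CUDA) is:
+
+ ```python
+ # Fall back to CPU when CUDA is unavailable.
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ ```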
+
+ ---
+
+ ### **Optimizer and Learning Rate Scheduler**
+
+ Define the optimizer and learning rate scheduler:
+
+ ```python
+ optimizer = AdamW(
+     model.parameters(),
+     lr=BASE_LR,
+     betas=(BETA1, BETA2),
+     weight_decay=WEIGHT_DECAY
+ )
+
+ # WARMUP_STEPS, TOTAL_STEPS, and MIN_LR are used below but never set in the
+ # settings above. A reasonable reading (our assumption) is one scheduler step
+ # per training batch, with a small floor learning rate:
+ steps_per_epoch = sum(len(loader) for loader in train_loaders.values())  # assumes a dict of DataLoaders
+ WARMUP_STEPS = WARMUP_EPOCHS * steps_per_epoch
+ TOTAL_STEPS = EPOCHS * steps_per_epoch
+ MIN_LR = 1e-6
+
+ def lr_lambda(current_step):
+     # Linear warmup from 0 to BASE_LR over the first WARMUP_STEPS steps...
+     if current_step < WARMUP_STEPS:
+         return current_step / WARMUP_STEPS
+     # ...then cosine decay from BASE_LR down to MIN_LR over the remaining steps.
+     scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
+     cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
+     return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR
+
+ scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
+ ```
+
+ - **AdamW Optimizer**:
+   - Uses decoupled weight decay (0.05) for better generalization.
+ - **Learning Rate Scheduler**:
+   - Combines linear warmup with cosine decay for smooth training; the closed-form schedule is spelled out below.
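+
+ For reference, after warmup the learning rate in the code above follows the standard warmup-plus-cosine form (with \(t_w\) = `WARMUP_STEPS`, \(T\) = `TOTAL_STEPS`, \(\eta_{\max}\) = `BASE_LR`, and \(\eta_{\min}\) = `MIN_LR`):
+
+ \[
+ \eta(t) = \eta_{\min} + \tfrac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\!\left(\pi \, \frac{t - t_w}{T - t_w}\right)\right), \qquad t_w \le t \le T,
+ \]
+
+ while for \(t < t_w\) it rises linearly as \(\eta(t) = \eta_{\max} \, t / t_w\).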
820
+
821
+ ---
822
+
823
+ ### **Training the Model**
824
+
825
+ Train the model using the `train_lwm` function:
826
+
827
+ ```python
828
+ pretrained_model = train_lwm(
829
+ model,
830
+ train_loaders,
831
+ val_loaders,
832
+ optimizer,
833
+ scheduler,
834
+ EPOCHS,
835
+ device=device
836
+ )
837
+ ```
+
+ - **Inputs**:
+   - **`model`**: The initialized LWM model.
+   - **`train_loaders` and `val_loaders`**: Data loaders for training and validation.
+   - **`optimizer` and `scheduler`**: The configured optimizer and learning rate scheduler.
+   - **`EPOCHS`**: Number of training epochs.
+   - **`device`**: Whether training runs on GPU or CPU.
+
+ - **Output**:
+   - **`pretrained_model`**: The trained LWM-v1.1 model, which you can save for later use (see the sketch below).
+
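+ To keep the result, you can persist the trained weights (illustrative; the file name is an assumption, and the `getattr` unwraps the `DataParallel` wrapper applied earlier, if present):
+
+ ```python
+ # Save the pretrained weights for later fine-tuning or inference.
+ to_save = getattr(pretrained_model, "module", pretrained_model)
+ torch.save(to_save.state_dict(), "models/lwm_v1.1_pretrained.pth")
+ ```
+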
+ ---
+
+ ### **Explore the Interactive Demo**
 
 Experience **LWM** interactively via our Hugging Face Spaces demo:
 [**Try the Interactive Demo!**](https://huggingface.co/spaces/wi-lab/lwm-interactive-demo)