Update README.md
README.md
CHANGED
@@ -20,7 +20,7 @@ LWM-v1.1 is a powerful **pre-trained** model developed as a **universal feature
 
 ### **How is LWM-v1.1 built?**
 
-The LWM-v1.1 architecture is built on transformers, designed to capture **
+The LWM-v1.1 architecture is built on transformers, designed to capture **dependencies** in wireless channel data. The model employs an updated version of **Masked Channel Modeling (MCM)**, increasing the masking ratio to make pretraining more challenging and effective. With **2D patch segmentation**, the model learns intricate relationships across both antennas and subcarriers, while **bucket-based batching** ensures efficient processing of variable-sized inputs. These enhancements make LWM-v1.1 highly scalable and adaptable, offering robust embeddings for diverse scenarios.
 
 ### **What does LWM-v1.1 offer?**
 
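To make the **2D patch segmentation** described above concrete, here is a minimal NumPy sketch (illustrative only, not the repository's preprocessing code; the function name `segment_into_patches` and the 32 × 32 example channel are assumptions): each 4-antenna × 4-subcarrier block becomes one patch of 4 * 4 * 2 = 32 real values.

```python
import numpy as np

def segment_into_patches(H, n_rows=4, n_cols=4):
    """Split a complex channel matrix into non-overlapping 2D patches.

    H: complex array of shape (num_antennas, num_subcarriers).
    Returns a (num_patches, n_rows * n_cols * 2) real array, where each
    patch concatenates its real and imaginary parts (4 * 4 * 2 = 32).
    """
    A, S = H.shape
    patches = []
    for i in range(0, A - A % n_rows, n_rows):
        for j in range(0, S - S % n_cols, n_cols):
            block = H[i:i + n_rows, j:j + n_cols]
            # Flatten and concatenate real and imaginary parts, giving
            # ELEMENT_LENGTH = n_rows * n_cols * 2 values per patch.
            patches.append(np.concatenate([block.real.ravel(), block.imag.ravel()]))
    return np.stack(patches)

# Example: a 32-antenna x 32-subcarrier channel yields (32/4) * (32/4) = 64 patches.
H = np.random.randn(32, 32) + 1j * np.random.randn(32, 32)
print(segment_into_patches(H).shape)  # (64, 32)
```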
@@ -692,7 +692,7 @@ DROPOUT = 0.1
 
 - **Data Parameters**:
   - **`N_ROWS` and `N_COLUMNS`**: Number of rows and columns in each channel patch (4 antennas × 4 subcarriers).
-  - **`ELEMENT_LENGTH`**: Number of elements in each patch, including real and imaginary parts (
+  - **`ELEMENT_LENGTH`**: Number of elements in each patch, including real and imaginary parts (4 * 4 * 2 = 32).
   - **`MAX_LEN`**: Maximum input length (including positional encoding).
 
 - **Model Hyperparameters**:
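The relationship between these data parameters can be sanity-checked in a few lines; this snippet is illustrative and not taken from the repository:

```python
N_ROWS, N_COLUMNS = 4, 4                 # 4 antennas x 4 subcarriers per patch
ELEMENT_LENGTH = N_ROWS * N_COLUMNS * 2  # real + imaginary parts
assert ELEMENT_LENGTH == 32
```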
@@ -784,7 +784,7 @@ device = torch.device("cuda:0")
 model = lwm_model.lwm().to(device)
 
 if load_model:
-    model_name = "
+    model_name = "model.pth"
     state_dict = torch.load(f"models/{model_name}", map_location=device)
     new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
     model.load_state_dict(new_state_dict)
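The `k.replace("module.", "")` step handles checkpoints saved from a model wrapped in `torch.nn.DataParallel`, which prefixes every parameter key with `module.`; stripping the prefix lets the same weights load into an unwrapped model. A tiny self-contained demonstration (the toy `nn.Linear` model is purely illustrative):

```python
import torch.nn as nn

# Wrapping any module in DataParallel adds a "module." prefix to its keys.
toy = nn.DataParallel(nn.Linear(4, 2))
print(list(toy.state_dict().keys()))  # ['module.weight', 'module.bias']
```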