Upload the pre-trained model and pre-training, inference, downstream, and utility scripts
- .gitignore +2 -0
- downstream.py +146 -0
- inference.py +52 -0
- input_preprocess.py +1020 -0
- lwm_model.py +154 -0
- main.py +120 -0
- models/model.pth +3 -0
- train.py +446 -0
- utils.py +247 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__*
/images
downstream.py
ADDED
@@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 10 11:11:58 2025

This script evaluates downstream task performance by comparing models trained
on raw channel representations versus those trained on LWM embeddings.

@author: Sadjad Alikhani
"""
#%% IMPORT PACKAGES & MODULES
from input_preprocess import tokenizer, scenarios_list
from inference import lwm_inference
from utils import prepare_loaders
from train import finetune
import lwm_model
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

#%% DOWNSTREAM DATA GENERATION
n_beams = 16
task = ['Beam Prediction', 'LoS/NLoS Classification'][1]
task_type = ["classification", "regression"][0]
visualization_method = ["pca", "umap", "tsne"][2]
input_types = ["cls_emb", "channel_emb", "raw"]
train_ratios = [.001, .01, .05, .1, .25, .5, .8]
fine_tuning_status = [None, ["layers.8", "layers.9", "layers.10", "layers.11"], "full"]
selected_scenario_names = [scenarios_list()[18]]
preprocessed_data, labels, raw_chs = tokenizer(
    selected_scenario_names,
    bs_idxs=[3],
    load_data=False,
    task=task,
    n_beams=n_beams)

#%% LOAD THE MODEL
gpu_ids = [0]
device = torch.device("cuda:0")
model = lwm_model.lwm().to(device)

model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
state_dict = torch.load(f"models/{model_name}", map_location=device)
new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)

model = nn.DataParallel(model, gpu_ids)
print(f"Model loaded successfully on GPU {device.index}")

#%% 2D EMBEDDING SPACE VISUALIZATION BEFORE FINE-TUNING
chs = lwm_inference(
    model,
    preprocessed_data,
    input_type="cls_emb",
    device=device,
    batch_size=64,
    visualization=False,
    labels=labels,
    visualization_method=visualization_method)

#%% FINE-TUNE
results = np.zeros((len(fine_tuning_status), len(input_types), len(train_ratios)))
for fine_tuning_stat_idx, fine_tuning_stat in enumerate(fine_tuning_status):
    for input_type_idx, input_type in enumerate(input_types):

        # Raw channels bypass the LWM backbone, so fine-tuning does not apply
        if input_type == "raw" and fine_tuning_stat is not None:
            continue

        selected_patches_idxs = None
        for train_ratio_idx, train_ratio in enumerate(train_ratios):

            print(f"\nfine-tuning status: {fine_tuning_stat}")
            print(f"input type: {input_type}")
            print(f"train ratio: {train_ratio}\n")

            # PREPARE LOADERS
            train_loader, val_loader, samples, target = prepare_loaders(
                preprocessed_data=preprocessed_data,
                labels=labels,
                selected_patches_idxs=selected_patches_idxs,
                input_type=input_type,
                task_type=task_type,
                train_ratio=train_ratio,
                batch_size=128,
                seed=42
            )

            # FINE-TUNE LWM
            fine_tuned_model, best_model_path, train_losses, val_losses, f1_scores, attn_maps_ft = finetune(
                base_model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                task_type=task_type,
                input_type=input_type,
                num_classes=n_beams if task == 'Beam Prediction' else 2 if task == 'LoS/NLoS Classification' else None,
                output_dim=target.shape[-1] if task_type == 'regression' else None,
                use_custom_head=True,
                fine_tune_layers=fine_tuning_stat,
                optimizer_config={"lr": 1e-3},
                epochs=15,
                device=device,
                task=task
            )

            results[fine_tuning_stat_idx][input_type_idx][train_ratio_idx] = f1_scores[-1]

markers = ['o', 's', 'D']
legend_labels = ['CLS Emb', 'CHS Emb', 'Raw']  # renamed so the task labels above are not overwritten
fine_tuning_status_labels = ['No FT', 'Partial FT', 'Full FT']
line_styles = ['-', '--', ':']
colors = plt.cm.viridis(np.linspace(0, 0.8, len(legend_labels)))
plt.figure(figsize=(12, 8), dpi=500)
for ft_idx, (ft_status_label, line_style) in enumerate(zip(fine_tuning_status_labels, line_styles)):
    for idx, (marker, legend_label, color) in enumerate(zip(markers, legend_labels, colors)):
        # For raw channels, only the "No FT" case is evaluated
        if legend_label == "Raw" and ft_status_label != "No FT":
            continue
        # Drop the fine-tuning suffix for the raw-channel curve
        plot_label = legend_label if legend_label == "Raw" else f"{legend_label} ({ft_status_label})"
        plt.plot(
            train_ratios,
            results[ft_idx, idx],
            marker=marker,
            linestyle=line_style,
            label=plot_label,
            color=color,
            linewidth=3,
            markersize=9
        )
plt.xscale('log')
plt.xlabel("Train Ratio", fontsize=20)
plt.ylabel("F1-Score", fontsize=20)
plt.legend(fontsize=17, loc="best")
plt.grid(True, linestyle="--", alpha=0.7)
plt.xticks(fontsize=17)
plt.yticks(fontsize=17)
plt.tight_layout()
plt.show()

#%% 2D EMBEDDING SPACE VISUALIZATION AFTER FINE-TUNING
chs = lwm_inference(
    fine_tuned_model.model,
    preprocessed_data,
    input_type="cls_emb",
    device=device,
    batch_size=64,
    visualization=False,
    labels=labels,
    visualization_method=visualization_method)
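The `fine_tune_layers` argument above selects which encoder blocks `train.finetune` unfreezes (train.py is part of this upload but not shown in this view). As a rough illustration of what passing None, a list of layer-name prefixes, or "full" typically amounts to, here is a minimal sketch; `set_trainable_layers` is a hypothetical helper, not a function from train.py:

import torch.nn as nn

def set_trainable_layers(model: nn.Module, fine_tune_layers):
    # Hypothetical illustration of the fine_tune_layers convention:
    # None     -> freeze the whole backbone (train only the task head)
    # "full"   -> unfreeze every parameter
    # [names]  -> unfreeze only parameters whose names match the given prefixes
    for name, param in model.named_parameters():
        if fine_tune_layers == "full":
            param.requires_grad = True
        elif fine_tune_layers is None:
            param.requires_grad = False
        else:
            param.requires_grad = any(prefix in name for prefix in fine_tune_layers)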
inference.py
ADDED
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024

This script performs LWM inference on raw channel representations.

@author: Sadjad Alikhani
"""
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import visualize_embeddings
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

#%%
def lwm_inference(model, data, input_type="cls_emb", device="cpu", batch_size=64, visualization=False, labels=None, visualization_method="t-sne"):

    if input_type == "raw":
        output_total = data
    else:
        dataset = TensorDataset(data)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        embeddings = []
        model.eval()
        with torch.no_grad():
            with tqdm(dataloader, desc="Inference", unit="batch") as t:
                for batch in t:

                    input_ids = batch[0].to(device)
                    output = model(input_ids)[0]

                    if input_type == "cls_emb":
                        batch_embeddings = output[:, 0, :]
                        embeddings.append(batch_embeddings)
                    elif input_type == "channel_emb":
                        batch_embeddings = output[:, 1:, :]
                        embeddings.append(batch_embeddings)

        output_total = torch.cat(embeddings, dim=0).float()

    if visualization:
        visualize_embeddings(output_total.view(output_total.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Embedding Space")
        visualize_embeddings(data.view(data.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Original Space")

    return output_total
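A minimal usage sketch of `lwm_inference`, assuming the tokenizer output and checkpoint naming used elsewhere in this upload and that the DeepMIMO scenario files are available locally:

import torch
import lwm_model
from input_preprocess import tokenizer, scenarios_list
from inference import lwm_inference

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = lwm_model.lwm().to(device)  # optionally load pre-trained weights before inference

# Tokenize one scenario; data holds [CLS] + patch tokens per user
data, labels, raw_chs = tokenizer([scenarios_list()[18]], bs_idxs=[3], load_data=False)

cls_emb = lwm_inference(model, data, input_type="cls_emb", device=device)      # (n_users, d_model)
ch_emb = lwm_inference(model, data, input_type="channel_emb", device=device)   # (n_users, n_patches, d_model)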
input_preprocess.py
ADDED
@@ -0,0 +1,1020 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 16:13:29 2024

This script generates preprocessed data from wireless communication scenarios,
including channel generation, patch generation, masking, and preparing raw
channels for the Transformer-based LWM model.

@author: Sadjad Alikhani
"""
import numpy as np
import os
from tqdm import tqdm
import time
import pickle
import DeepMIMOv3
import torch
from collections import defaultdict
from utils import generate_gaussian_noise, plot_coverage

#%% Scenarios List
def scenarios_list():
    scen_list = np.array([
        'city_0_newyork', 'city_1_losangeles', 'city_2_chicago', 'city_3_houston',
        'city_4_phoenix', 'city_5_philadelphia', 'city_6_miami', 'city_7_sandiego',
        'city_8_dallas', 'city_9_sanfrancisco', 'city_10_austin', 'city_11_santaclara',
        'city_12_fortworth', 'city_13_columbus', 'city_14_charlotte', 'city_15_indianapolis',
        'city_16_sanfrancisco', 'city_17_seattle', 'city_18_denver', 'city_19_oklahoma',
        'asu_campus1_v1', 'asu_campus1_v2', 'asu_campus1_v3', 'asu_campus1_v4',
        'asu_campus1_v5', 'asu_campus1_v6', 'asu_campus1_v7', 'asu_campus1_v8',
        'asu_campus1_v9', 'asu_campus1_v10', 'asu_campus1_v11', 'asu_campus1_v12',
        'asu_campus1_v13', 'asu_campus1_v14', 'asu_campus1_v15', 'asu_campus1_v16',
        'asu_campus1_v17', 'asu_campus1_v18', 'asu_campus1_v19', 'asu_campus1_v20',
        'Boston5G_3p5_v1', 'Boston5G_3p5_v2', 'Boston5G_3p5_v3', 'Boston5G_3p5_v4',
        'Boston5G_3p5_v5', 'Boston5G_3p5_v6', 'Boston5G_3p5_v7', 'Boston5G_3p5_v8',
        'Boston5G_3p5_v9', 'Boston5G_3p5_v10', 'Boston5G_3p5_v11', 'Boston5G_3p5_v12',
        'Boston5G_3p5_v13', 'Boston5G_3p5_v14', 'Boston5G_3p5_v15', 'Boston5G_3p5_v16',
        'Boston5G_3p5_v17', 'Boston5G_3p5_v18', 'Boston5G_3p5_v19', 'Boston5G_3p5_v20',
        'O1_3p5_v1', 'O1_3p5_v2', 'O1_3p5_v3', 'O1_3p5_v4', 'O1_3p5_v5',
        'O1_3p5_v6', 'O1_3p5_v7', 'O1_3p5_v8', 'O1_3p5_v9', 'O1_3p5_v10',
        'O1_3p5_v11', 'O1_3p5_v12', 'O1_3p5_v13', 'O1_3p5_v14', 'O1_3p5_v15',
        'O1_3p5_v16', 'O1_3p5_v17', 'O1_3p5_v18', 'O1_3p5_v19', 'O1_3p5_v20',
        'asu_campus1', 'O1_3p5', 'Boston5G_3p5',
        'city_0_newyork_v16x64', 'city_1_losangeles_v16x64', 'city_2_chicago_v16x64',
        'city_3_houston_v16x64', 'city_4_phoenix_v16x64', 'city_5_philadelphia_v16x64',
        'city_6_miami_v16x64', 'city_7_sandiego_v16x64', 'city_8_dallas_v16x64',
        'city_9_sanfrancisco_v16x64'
    ])
    return scen_list

#%% Token Generation
def patch_gen(N_ROWS=4, N_COLUMNS=4, selected_scenario_names=None,
              manual_data=None, bs_idxs=[1, 2, 3], load_data=False,
              save_dir="data", task="LoS/NLoS Classification",
              n_beams=64, o1_bs_idx=[4]):

    os.makedirs(save_dir, exist_ok=True)

    if manual_data is not None:
        patches = patch_maker(np.expand_dims(np.array(manual_data), axis=1))
    else:
        deepmimo_data = []
        for scenario_name in selected_scenario_names:
            if "O1" in scenario_name:  # make an exception for BS idxs of the O1 scenario
                if o1_bs_idx is None:
                    bs_idxs = [4, 15]
                else:
                    bs_idxs = o1_bs_idx
            for bs_idx in bs_idxs:
                if has_version_suffix(scenario_name) and bs_idx in [2, 3]:
                    continue
                if not load_data:
                    print(f"\nGenerating data for scenario: {scenario_name}, BS #{bs_idx}")
                    data, n_ant_bs, n_subcarriers = DeepMIMO_data_gen(scenario_name, bs_idx)
                    file_name = f"{save_dir}/{scenario_name}_ant{n_ant_bs}_sub{n_subcarriers}_bs{bs_idx}.npy"
                    np.save(file_name, data)
                    print(f"Data saved to {file_name}")
                    deepmimo_data.append(data)
                else:
                    n_ant_bs, n_subcarriers = parametersv2(scenario_name, bs_idx)
                    print(f"\nLoading data for scenario: {scenario_name}, BS #{bs_idx}")
                    file_name = f"{save_dir}/{scenario_name}_ant{n_ant_bs}_sub{n_subcarriers}_bs{bs_idx}.npy"
                    data = np.load(file_name, allow_pickle=True).item()
                    print(f"Data loaded from {file_name}")
                    deepmimo_data.append(data)

        cleaned_deepmimo_data = [deepmimo_data_cleaning(deepmimo_data[scenario_idx]) for scenario_idx in range(len(deepmimo_data))]  # n_scenarios*n_bs_idxs
        patches = [patch_maker(cleaned_deepmimo_data[scenario_idx], N_ROWS, N_COLUMNS) for scenario_idx in range(len(deepmimo_data))]
        raw_chs = torch.tensor(cleaned_deepmimo_data[0]).squeeze(1)
        raw_chs = raw_chs.view(raw_chs.size(0), -1)
        raw_chs = torch.hstack((raw_chs.real, raw_chs.imag))

        if task:
            labels = [label_gen(task, deepmimo_data[scenario_idx], selected_scenario_names[scenario_idx], n_beams=n_beams) for scenario_idx in range(len(deepmimo_data))]
            return patches, torch.tensor(labels[0]), raw_chs.view(raw_chs.size(0), -1)
        else:
            return patches, raw_chs.view(raw_chs.size(0), -1)

#%%
def tokenizer(selected_scenario_names,
              bs_idxs=[1, 2, 3],
              load_data=False,
              task="LoS/NLoS Classification",
              n_beams=64,
              MAX_LEN=513,
              masking_percent=.40,
              mask=False,
              seed=42,
              snr=None):

    patches, labels, raw_chs = patch_gen(
        selected_scenario_names=selected_scenario_names,
        bs_idxs=bs_idxs,
        load_data=load_data,
        task=task,
        n_beams=n_beams
    )

    patches = [patch for patch_list in patches for patch in patch_list]
    print("Total number of samples:", len(patches))

    grouped_data = defaultdict(list)  # Group samples by sequence length
    grouped_data_2 = []

    for user_idx in tqdm(range(len(patches)), desc="Processing items"):
        patch_size = patches[user_idx].shape[1]
        n_patches = patches[user_idx].shape[0]
        n_masks_half = int(masking_percent * n_patches)

        word2id = {
            '[CLS]': 0.2 * np.ones((patch_size)),
            '[MASK]': 0.1 * np.ones((patch_size))
        }

        sample = make_sample(
            user_idx, patches, word2id, n_patches, n_masks_half, patch_size, MAX_LEN, mask=mask, seed=seed
        )

        if mask:
            seq_length = len(sample[0])
            grouped_data[seq_length].append(sample)
        else:
            grouped_data_2.append(sample)

    if mask:
        # Normalize keys to 0, 1, 2, ...
        normalized_grouped_data = {i: grouped_data[key] for i, key in enumerate(sorted(grouped_data.keys()))}
    else:
        normalized_grouped_data = torch.stack(grouped_data_2, dim=0)
        # normalized_grouped_data = grouped_data_2
        if snr is not None:
            normalized_grouped_data += generate_gaussian_noise(normalized_grouped_data, snr)
        # normalized_grouped_data = {i: grouped_data[key] for i, key in enumerate(sorted(grouped_data.keys()))}

    return normalized_grouped_data, labels, raw_chs

#%% REMOVE ZERO CHANNELS AND SCALE
def deepmimo_data_cleaning(deepmimo_data):
    idxs = np.where(deepmimo_data['user']['LoS'] != -1)[0]
    cleaned_deepmimo_data = deepmimo_data['user']['channel'][idxs]
    return np.array(cleaned_deepmimo_data) * 1e6

#%%
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, MAX_LEN, mask=True, seed=None):

    if seed is not None:
        np.random.seed(seed)

    # Step 1: Retrieve tokens and prepend [CLS]
    tokens = patch[user_idx]
    input_ids = np.vstack((word2id['[CLS]'], tokens))

    # Step 2: Mask real and imaginary patches
    tokens_size = int(n_patches)  # int(n_patches / 2)
    masked_pos = np.random.choice(range(1, tokens_size), size=n_masks, replace=False)

    masked_tokens = []
    for pos in masked_pos:
        original_masked_tokens = input_ids[pos].copy()
        masked_tokens.append(original_masked_tokens)
        if mask:
            rnd_num = np.random.rand()
            if rnd_num < 0.1:
                input_ids[pos] = np.random.rand(patch_size)  # Replace with random values
            elif rnd_num < 0.9:
                input_ids[pos] = word2id['[MASK]']  # Replace with [MASK]

    if not mask:
        return torch.tensor(input_ids)
    else:
        return [input_ids, masked_tokens, masked_pos]

#%% Patch GENERATION
def patch_maker(original_ch, patch_rows, patch_cols):
    # Step 1: Remove the singleton channel dimension
    n_samples, _, n_rows, n_cols = original_ch.shape  # Unpack shape
    original_ch = original_ch[:, 0]  # Remove the singleton dimension

    # Step 2: Split into real and imaginary parts and interleave them
    flat_real = original_ch.real
    flat_imag = original_ch.imag

    # Interleave real and imaginary parts along the last axis
    interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
    interleaved[:, :, 0::2] = flat_real
    interleaved[:, :, 1::2] = flat_imag

    # Step 3: Compute the number of patches along rows and columns
    n_patches_rows = int(np.ceil(n_rows / patch_rows))
    n_patches_cols = int(np.ceil(n_cols / patch_cols))

    # Step 4: Pad the matrix if necessary to make it divisible by patch size
    padded_rows = n_patches_rows * patch_rows - n_rows
    padded_cols = n_patches_cols * patch_cols - n_cols
    if padded_rows > 0 or padded_cols > 0:
        interleaved = np.pad(
            interleaved,
            ((0, 0), (0, padded_rows), (0, padded_cols * 2)),  # Double padding for interleaved axis
            mode='constant',
            constant_values=0,
        )

    # Step 5: Create patches by dividing into blocks
    n_samples, padded_rows, padded_cols = interleaved.shape
    padded_cols //= 2  # Adjust for interleaving (real and imaginary parts count as one)
    patches = []

    for i in range(0, padded_rows, patch_rows):
        for j in range(0, padded_cols, patch_cols):
            patch = interleaved[:, i:i + patch_rows, j * 2:(j + patch_cols) * 2]
            patches.append(patch.reshape(n_samples, -1))  # Flatten each patch

    # Step 6: Stack patches to form the final array
    patches = np.stack(patches, axis=1)  # Shape: (num_samples, n_patches, patch_rows * patch_cols * 2)

    return patches

#%% Data Generation for Scenario Areas
def DeepMIMO_data_gen(scenario, bs_idx):
    import DeepMIMOv3
    parameters, row_column_users = get_parameters(scenario, bs_idx)
    deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)

    if "O1" in scenario:
        hops = [2, 2]
    else:
        hops = [1, 1]

    uniform_idxs = uniform_sampling(deepMIMO_dataset, hops, len(parameters['user_rows']),
                                    users_per_row=row_column_users[scenario]['n_per_row'])
    data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]

    n_ant_bs = parameters['bs_antenna']['shape'][0]
    n_subcarriers = parameters['OFDM']['subcarriers']

    return data, n_ant_bs, n_subcarriers

#%%
def parametersv2(scenario, bs_idx):
    parameters, _ = get_parameters(scenario, bs_idx)
    n_ant_bs = parameters['bs_antenna']['shape'][0]
    n_subcarriers = parameters['OFDM']['subcarriers']
    return n_ant_bs, n_subcarriers

#%%
def get_parameters(scenario, bs_idx=1):

    n_ant_ue = 1
    scs = 30e3

    row_column_users = scenario_prop()

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario.split("_v")[0]

    n_ant_bs = row_column_users[scenario]['n_ant_bs']
    n_subcarriers = row_column_users[scenario]['n_subcarriers']
    parameters['active_BS'] = np.array([bs_idx])

    if isinstance(row_column_users[scenario]['n_rows'], int):
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
    else:
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                            row_column_users[scenario]['n_rows'][1])

    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)

    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9
    parameters['num_paths'] = 20

    return parameters, row_column_users

#%% Sampling and Data Selection
def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
    cols = np.arange(users_per_row, step=sampling_div[0])
    rows = np.arange(n_rows, step=sampling_div[1])
    uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])
    return uniform_idxs

def select_by_idx(dataset, idxs):
    dataset_t = []  # Trimmed dataset
    for bs_idx in range(len(dataset)):
        dataset_t.append({})
        for key in dataset[bs_idx].keys():
            dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
            dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}
    return dataset_t

#%%
def inverse_patch_maker(patches, original_shape, patch_rows, patch_cols):
    """
    Reconstructs the original channel matrix from patches.

    Args:
        patches (numpy array): Patches of shape (num_samples, n_patches, patch_rows * patch_cols * 2).
        original_shape (tuple): Original shape of the channel matrix (num_samples, 1, n_rows, n_cols).
        patch_rows (int): Number of rows in each patch.
        patch_cols (int): Number of columns in each patch.

    Returns:
        numpy array: Reconstructed complex-valued channel matrix of shape (num_samples, 1, n_rows, n_cols).
    """
    n_samples, n_patches, patch_size = patches.shape
    _, _, n_rows, n_cols = original_shape

    # Ensure patch dimensions match
    assert patch_rows * patch_cols * 2 == patch_size, "Patch size mismatch with provided dimensions."

    # Compute the number of patches along rows and columns
    n_patches_rows = int(np.ceil(n_rows / patch_rows))
    n_patches_cols = int(np.ceil(n_cols / patch_cols))

    # Reassemble interleaved array from patches
    interleaved = np.zeros((n_samples, n_patches_rows * patch_rows, n_patches_cols * patch_cols * 2), dtype=np.float32)
    patch_idx = 0

    for i in range(n_patches_rows):
        for j in range(n_patches_cols):
            patch = patches[:, patch_idx, :].reshape(n_samples, patch_rows, patch_cols * 2)
            interleaved[:, i * patch_rows:(i + 1) * patch_rows, j * patch_cols * 2:(j + 1) * patch_cols * 2] = patch
            patch_idx += 1

    # Remove padding if necessary
    interleaved = interleaved[:, :n_rows, :n_cols * 2]

    # Separate real and imaginary parts
    flat_real = interleaved[:, :, 0::2]
    flat_imag = interleaved[:, :, 1::2]

    # Reconstruct the complex-valued original channel
    reconstructed = flat_real + 1j * flat_imag

    # Add the singleton channel dimension back
    reconstructed = reconstructed[:, np.newaxis, :, :]  # Shape: (num_samples, 1, n_rows, n_cols)

    return reconstructed

#%%
def label_gen(task, data, scenario, n_beams=64):

    idxs = np.where(data['user']['LoS'] != -1)[0]

    if task == 'LoS/NLoS Classification':
        label = data['user']['LoS'][idxs]

        losChs = np.where(data['user']['LoS'] == -1, np.nan, data['user']['LoS'])
        plot_coverage(data['user']['location'], losChs, cbar_title='LoS status')

    elif task == 'Beam Prediction':
        parameters, row_column_users = get_parameters(scenario, bs_idx=1)
        n_users = len(data['user']['channel'])
        n_subbands = 1
        fov = 180

        # Setup Beamformers
        beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)

        F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
                                    phi=azi*np.pi/180,
                                    kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
                       for azi in beam_angles])

        full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
        for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
            if data['user']['LoS'][ue_idx] == -1:
                full_dbm[:, :, ue_idx] = np.nan
            else:
                chs = F1 @ data['user']['channel'][ue_idx]
                full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
                full_dbm[:, :, ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)

        best_beams = np.argmax(np.mean(full_dbm, axis=1), axis=0)
        best_beams = best_beams.astype(float)
        best_beams[np.isnan(full_dbm[0, 0, :])] = np.nan
        # max_bf_pwr = np.max(np.mean(full_dbm, axis=1), axis=0)

        plot_coverage(data['user']['location'], best_beams, tx_pos=data['location'],
                      tx_ori=parameters['bs_antenna']['rotation']*np.pi/180,
                      cbar_title='Best beam index')

        label = best_beams[idxs]

    return label.astype(int)

#%%
def steering_vec(array, phi=0, theta=0, kd=np.pi):
    idxs = DeepMIMOv3.ant_indices(array)
    resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
    return resp / np.linalg.norm(resp)

#%%
import re
def has_version_suffix(s):
    pattern = r"_v([1-9]|1[0-9]|20)$"
    return bool(re.search(pattern, s))

#%%
def scenario_prop():
    row_column_users = {
        'city_0_newyork': {'n_rows': 109, 'n_per_row': 291, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'city_1_losangeles': {'n_rows': 142, 'n_per_row': 201, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'city_2_chicago': {'n_rows': 139, 'n_per_row': 200, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'city_3_houston': {'n_rows': 154, 'n_per_row': 202, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'city_4_phoenix': {'n_rows': 198, 'n_per_row': 214, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'city_5_philadelphia': {'n_rows': 239, 'n_per_row': 164, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'city_6_miami': {'n_rows': 199, 'n_per_row': 216, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'city_7_sandiego': {'n_rows': 207, 'n_per_row': 176, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_8_dallas': {'n_rows': 207, 'n_per_row': 190, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'city_9_sanfrancisco': {'n_rows': 196, 'n_per_row': 206, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'city_10_austin': {'n_rows': 255, 'n_per_row': 137, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'city_11_santaclara': {'n_rows': 117, 'n_per_row': 285, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'city_12_fortworth': {'n_rows': 214, 'n_per_row': 179, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'city_13_columbus': {'n_rows': 178, 'n_per_row': 240, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'city_14_charlotte': {'n_rows': 216, 'n_per_row': 177, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'city_15_indianapolis': {'n_rows': 200, 'n_per_row': 196, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'city_16_sanfrancisco': {'n_rows': 201, 'n_per_row': 208, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'city_17_seattle': {'n_rows': 185, 'n_per_row': 205, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'city_18_denver': {'n_rows': 212, 'n_per_row': 204, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'city_19_oklahoma': {'n_rows': 204, 'n_per_row': 188, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'asu_campus1_v1': {'n_rows': [0, 1*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'asu_campus1_v2': {'n_rows': [1*int(321/20), 2*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'asu_campus1_v3': {'n_rows': [2*int(321/20), 3*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'asu_campus1_v4': {'n_rows': [3*int(321/20), 4*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'asu_campus1_v5': {'n_rows': [4*int(321/20), 5*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'asu_campus1_v6': {'n_rows': [5*int(321/20), 6*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'asu_campus1_v7': {'n_rows': [6*int(321/20), 7*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'asu_campus1_v8': {'n_rows': [7*int(321/20), 8*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'asu_campus1_v9': {'n_rows': [8*int(321/20), 9*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'asu_campus1_v10': {'n_rows': [9*int(321/20), 10*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'asu_campus1_v11': {'n_rows': [10*int(321/20), 11*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'asu_campus1_v12': {'n_rows': [11*int(321/20), 12*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'asu_campus1_v13': {'n_rows': [12*int(321/20), 13*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'asu_campus1_v14': {'n_rows': [13*int(321/20), 14*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'asu_campus1_v15': {'n_rows': [14*int(321/20), 15*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'asu_campus1_v16': {'n_rows': [15*int(321/20), 16*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'asu_campus1_v17': {'n_rows': [16*int(321/20), 17*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'asu_campus1_v18': {'n_rows': [17*int(321/20), 18*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'asu_campus1_v19': {'n_rows': [18*int(321/20), 19*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'asu_campus1_v20': {'n_rows': [19*int(321/20), 20*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'Boston5G_3p5_v1': {'n_rows': [812, 812 + 1*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'Boston5G_3p5_v2': {'n_rows': [812 + 1*int((1622-812)/20), 812 + 2*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'Boston5G_3p5_v3': {'n_rows': [812 + 2*int((1622-812)/20), 812 + 3*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'Boston5G_3p5_v4': {'n_rows': [812 + 3*int((1622-812)/20), 812 + 4*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'Boston5G_3p5_v5': {'n_rows': [812 + 4*int((1622-812)/20), 812 + 5*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'Boston5G_3p5_v6': {'n_rows': [812 + 5*int((1622-812)/20), 812 + 6*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'Boston5G_3p5_v7': {'n_rows': [812 + 6*int((1622-812)/20), 812 + 7*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'Boston5G_3p5_v8': {'n_rows': [812 + 7*int((1622-812)/20), 812 + 8*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'Boston5G_3p5_v9': {'n_rows': [812 + 8*int((1622-812)/20), 812 + 9*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'Boston5G_3p5_v10': {'n_rows': [812 + 9*int((1622-812)/20), 812 + 10*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'Boston5G_3p5_v11': {'n_rows': [812 + 10*int((1622-812)/20), 812 + 11*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'Boston5G_3p5_v12': {'n_rows': [812 + 11*int((1622-812)/20), 812 + 12*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'Boston5G_3p5_v13': {'n_rows': [812 + 12*int((1622-812)/20), 812 + 13*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'Boston5G_3p5_v14': {'n_rows': [812 + 13*int((1622-812)/20), 812 + 14*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'Boston5G_3p5_v15': {'n_rows': [812 + 14*int((1622-812)/20), 812 + 15*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'Boston5G_3p5_v16': {'n_rows': [812 + 15*int((1622-812)/20), 812 + 16*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'Boston5G_3p5_v17': {'n_rows': [812 + 16*int((1622-812)/20), 812 + 17*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'Boston5G_3p5_v18': {'n_rows': [812 + 17*int((1622-812)/20), 812 + 18*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'Boston5G_3p5_v19': {'n_rows': [812 + 18*int((1622-812)/20), 812 + 19*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'Boston5G_3p5_v20': {'n_rows': [812 + 19*int((1622-812)/20), 812 + 20*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'O1_3p5_v1': {'n_rows': [0*int(3852/12), 1*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'O1_3p5_v2': {'n_rows': [1*int(3852/12), 2*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'O1_3p5_v3': {'n_rows': [2*int(3852/12), 3*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'O1_3p5_v4': {'n_rows': [3*int(3852/12), 4*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'O1_3p5_v5': {'n_rows': [4*int(3852/12), 5*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'O1_3p5_v6': {'n_rows': [5*int(3852/12), 6*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'O1_3p5_v7': {'n_rows': [6*int(3852/12), 7*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'O1_3p5_v8': {'n_rows': [7*int(3852/12), 8*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'O1_3p5_v9': {'n_rows': [8*int(3852/12), 9*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'O1_3p5_v10': {'n_rows': [9*int(3852/12), 10*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'O1_3p5_v11': {'n_rows': [10*int(3852/12), 11*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'O1_3p5_v12': {'n_rows': [11*int(3852/12), 12*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'O1_3p5_v13': {'n_rows': [12*int(3852/12)+0*int(1351/10), 12*int(3852/12)+1*int(1351/10)], 'n_per_row': 361, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'O1_3p5_v14': {'n_rows': [12*int(3852/12)+1*int(1351/10), 12*int(3852/12)+2*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'O1_3p5_v15': {'n_rows': [12*int(3852/12)+2*int(1351/10), 12*int(3852/12)+3*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'O1_3p5_v16': {'n_rows': [12*int(3852/12)+3*int(1351/10), 12*int(3852/12)+4*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'O1_3p5_v17': {'n_rows': [12*int(3852/12)+4*int(1351/10), 12*int(3852/12)+5*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'O1_3p5_v18': {'n_rows': [12*int(3852/12)+5*int(1351/10), 12*int(3852/12)+6*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'O1_3p5_v19': {'n_rows': [12*int(3852/12)+6*int(1351/10), 12*int(3852/12)+7*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'O1_3p5_v20': {'n_rows': [12*int(3852/12)+7*int(1351/10), 12*int(3852/12)+8*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'city_0_newyork_v16x64': {'n_rows': 109, 'n_per_row': 291, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_1_losangeles_v16x64': {'n_rows': 142, 'n_per_row': 201, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_2_chicago_v16x64': {'n_rows': 139, 'n_per_row': 200, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_3_houston_v16x64': {'n_rows': 154, 'n_per_row': 202, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_4_phoenix_v16x64': {'n_rows': 198, 'n_per_row': 214, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_5_philadelphia_v16x64': {'n_rows': 239, 'n_per_row': 164, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_6_miami_v16x64': {'n_rows': 199, 'n_per_row': 216, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_7_sandiego_v16x64': {'n_rows': 207, 'n_per_row': 176, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_8_dallas_v16x64': {'n_rows': 207, 'n_per_row': 190, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_9_sanfrancisco_v16x64': {'n_rows': 196, 'n_per_row': 206, 'n_ant_bs': 16, 'n_subcarriers': 64}
    }
    return row_column_users
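Since patch_maker interleaves real and imaginary parts and pads to a whole number of patches, inverse_patch_maker undoes it exactly when no padding is needed. A small self-check sketch on a synthetic complex channel (assumes input_preprocess.py and its imports, e.g. DeepMIMOv3, are importable):

import numpy as np
from input_preprocess import patch_maker, inverse_patch_maker

# Synthetic complex channel: 5 users, 1 UE antenna, 8 BS antennas, 32 subcarriers
ch = np.random.randn(5, 1, 8, 32) + 1j * np.random.randn(5, 1, 8, 32)

patches = patch_maker(ch, 4, 4)                    # -> (5, 16, 32): 2x8 patches of 4x4 complex values
recovered = inverse_patch_maker(patches, ch.shape, 4, 4)

print(patches.shape)
print(np.allclose(recovered, ch, atol=1e-5))       # True up to float32 rounding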
lwm_model.py
ADDED
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 19:23:54 2024

This script defines the LWM model architecture.

@author: Sadjad Alikhani
"""
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

#%%
class LayerNormalization(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias


class Embedding(nn.Module):
    def __init__(self, element_length, d_model, max_len=513):
        super().__init__()
        self.element_length = element_length
        self.d_model = d_model
        self.proj = nn.Linear(element_length, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.norm = LayerNormalization(d_model)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos_encodings = self.pos_embed(pos)
        tok_emb = self.proj(x.float())
        embedding = tok_emb + pos_encodings
        return self.norm(embedding)


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        self.d_k = d_model // n_heads
        self.d_v = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, self.d_k * n_heads)
        self.W_K = nn.Linear(d_model, self.d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.linear = nn.Linear(n_heads * self.d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        context, attn = self.scaled_dot_attn(q_s, k_s, v_s)
        output = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        output = self.linear(output)
        return residual + self.dropout(output), attn


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.fc2(self.dropout(F.relu(self.fc1(x))))


class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff, dropout)
        self.norm1 = LayerNormalization(d_model)
        self.norm2 = LayerNormalization(d_model)

    def forward(self, enc_inputs):
        # Self-Attention with Add & Norm
        attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
        attn_outputs = self.norm1(enc_inputs + attn_outputs)  # Add & Norm

        # Feed-Forward with Add & Norm
        ff_outputs = self.pos_ffn(attn_outputs)
        enc_outputs = self.norm2(attn_outputs + ff_outputs)  # Add & Norm

        return enc_outputs, attn


class lwm(nn.Module):
    def __init__(self, element_length=32, d_model=128, n_layers=12, max_len=513, n_heads=8, dropout=0.1):
        super().__init__()
        self.embedding = Embedding(element_length, d_model, max_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_model*4, dropout) for _ in range(n_layers)]
        )
        self.linear = nn.Linear(d_model, d_model)
        self.norm = LayerNormalization(d_model)

        embed_weight = self.embedding.proj.weight
        _, n_dim = embed_weight.size()
        self.decoder = nn.Linear(d_model, n_dim, bias=False)
        self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

    @classmethod
    def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
        model = cls().to(device)
        model.load_state_dict(torch.load(ckpt_name, map_location=device))
        print(f"Model loaded successfully from {ckpt_name}")
        return model

    def forward(self, input_ids, masked_pos=None):
        # Step 1: Embedding
        output = self.embedding(input_ids)
        attention_maps = []

        # Step 2: Pass through Encoder Layers
        for layer in self.layers:
            output, attn = layer(output)
            attention_maps.append(attn)

        # If masked_pos is provided, perform masked token prediction
        if masked_pos is not None:
            masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
            h_masked = torch.gather(output, 1, masked_pos)
            h_masked = self.norm(F.relu(self.linear(h_masked)))
            logits_lm = self.decoder(h_masked) + self.decoder_bias
            return logits_lm, output, attention_maps
        else:
            return output, attention_maps
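A quick shape check of the encoder, assuming the default configuration above (32-dimensional patch tokens, d_model=128, 12 layers):

import torch
import lwm_model

model = lwm_model.lwm()                      # element_length=32, d_model=128, n_layers=12
x = torch.rand(2, 33, 32)                    # [CLS] + 32 patch tokens per sample
output, attention_maps = model(x)            # output: (2, 33, 128), 12 attention maps

masked_pos = torch.randint(1, 33, (2, 5))    # 5 masked patch positions per sample
logits_lm, output, attention_maps = model(x, masked_pos)
print(logits_lm.shape)                       # torch.Size([2, 5, 32]): reconstructed patch values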
main.py
ADDED
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 21 13:24:21 2024

This script pre-trains the LWM model

@author: salikha4
"""
import torch
import torch.nn as nn
from torch.utils.data import random_split
from input_preprocess import tokenizer, scenarios_list
from utils import create_dataloader, count_parameters
import numpy as np
import lwm_model
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import AdamW
from train import train_lwm
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
#%% SETTINGS
EPOCHS = 50
BATCH_SIZE = 128
VAL_BATCH_SIZE = 64
WARMUP_EPOCHS = 5
BASE_LR = 5e-4
N_ROWS = 4
N_COLUMNS = 4
ELEMENT_LENGTH = N_ROWS*N_COLUMNS*2
D_MODEL = 128
MAX_LEN = 513
N_LAYERS = 12
WEIGHT_DECAY = 0.05
BETA1 = 0.9
BETA2 = 0.999
MASK_PERCENT = 0.40
N_HEADS = 8
DROPOUT = 0.1
#%% GENERATE DATASET
bs_idxs = [1, 2, 3]
selected_scenario_names = scenarios_list()[:80]
preprocessed_data = tokenizer(
    selected_scenario_names,
    MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,
    seed=42
)
#%% SPLIT DATASET
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
train_ratio = 0.8
val_ratio = 0.2
train_data = {}
val_data = {}
test_data = {}
for key, samples in preprocessed_data.items():
    print(f"key: {key}")
    total_samples = len(samples)
    train_size = int(train_ratio * total_samples)
    val_size = int(val_ratio * total_samples)
    test_size = total_samples - val_size - train_size

    train_data[key], val_data[key], test_data[key] = random_split(
        samples, [train_size, val_size, test_size]
    )
train_loaders = create_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loaders = create_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
#%% INITIALIZE MODEL
load_model = True
gpu_ids = [0]
device = torch.device("cuda:0")
model = lwm_model.lwm().to(device)

if load_model:
    model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
    state_dict = torch.load(f"models/{model_name}", map_location=device)
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)

model = nn.DataParallel(model, gpu_ids)
print(f"Model loaded successfully on GPU {device.index}")

n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
#%% OPTIMIZER AND SCHEDULER
BASE_LR = 5e-5
MIN_LR = 1e-8
TOTAL_STEPS = sum(len(loader) for loader in train_loaders.values()) * EPOCHS
WARMUP_STEPS = sum(len(loader) for loader in train_loaders.values()) * WARMUP_EPOCHS

optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY
)
def lr_lambda(current_step):
    if current_step < WARMUP_STEPS:
        # Linear warmup
        return current_step / WARMUP_STEPS
    else:
        # Scaled cosine decay
        scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
        return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
#%% PRE-TRAIN THE MODEL
pretrained_model = train_lwm(
    model,
    train_loaders,
    val_loaders,
    optimizer,
    scheduler,
    EPOCHS,
    device=device
)
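Schedule note (illustrative only, not part of main.py): the lr_lambda above combines a linear warmup with a cosine decay that is rescaled so the effective learning rate peaks at BASE_LR after warmup and floors at MIN_LR. A small self-contained sketch, with hypothetical step counts, to see the shape of the schedule:

    # Illustrative check of the warmup + scaled-cosine multiplier (constants copied from above,
    # WARMUP_STEPS and TOTAL_STEPS are made-up values for this sketch).
    import numpy as np

    BASE_LR, MIN_LR = 5e-5, 1e-8
    WARMUP_STEPS, TOTAL_STEPS = 1_000, 10_000

    def lr_lambda(step):
        if step < WARMUP_STEPS:
            return step / WARMUP_STEPS
        progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
        cosine = 0.5 * (1 + np.cos(np.pi * progress))
        return cosine * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR

    for step in [0, 500, 1_000, 5_500, 10_000]:
        print(step, BASE_LR * lr_lambda(step))
    # -> 0.0, 2.5e-5 (mid-warmup), 5e-5 (peak), ~2.5e-5 (mid-decay), ~1e-8 (floor)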
models/model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:485611f1a0f819f9c673827b8e613887b39672e97072bd7a412866b49d8dd40f
size 9960738
train.py
ADDED
@@ -0,0 +1,446 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 20 09:32:12 2024

This script contains the LWM pre-training and task-specific fine-tuning functions.

@author: Sadjad Alikhani
"""
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import csv
from utils import count_parameters
import time
#%% LOSS FUNCTION
def nmse_loss(y_pred, y_true):
    y_pred_flat = y_pred.view(y_pred.size(0), -1)
    y_true_flat = y_true.view(y_true.size(0), -1)
    mse = torch.sum((y_true_flat - y_pred_flat)**2, dim=-1)
    normalization = torch.sum(y_true_flat**2, dim=-1)
    return mse / normalization
#%%
def train_lwm(model, train_loaders, val_loaders, optimizer, scheduler, epochs, device, save_dir="models", log_file="training_log.csv"):

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Initialize CSV log
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')

    for epoch in range(epochs):
        model.train()
        train_nmse = 0.0
        train_samples = 0

        # Training loop across all buckets
        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    optimizer.zero_grad()

                    # Move data to device
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]

                    # Forward pass
                    logits_lm, _, _ = model(input_ids, masked_pos)

                    # Compute NMSE
                    loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]

                    # Update progress bar
                    t.set_postfix({"nmse": train_nmse/train_samples, "lr": scheduler.get_last_lr()[0]})

        # Average NMSE across training batches
        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        if epoch % 2 == 0:
            # Validation loop across all buckets
            model.eval()
            val_nmse = 0.0
            val_samples = 0
            with torch.no_grad():
                print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
                for length, val_loader in val_loaders.items():
                    print(f"Processing sequences of length {length}")
                    with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                        for batch in t:
                            # Move data to device
                            input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]

                            # Forward pass
                            logits_lm, _, _ = model(input_ids, masked_pos)

                            # Compute NMSE
                            loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
                            val_nmse += loss.item()
                            val_samples += input_ids.shape[0]

                            # Update progress bar
                            t.set_postfix({"nmse": val_nmse/val_samples})

            # Average NMSE across validation batches
            val_nmse /= max(val_samples, 1)
            val_nmse_losses.append(val_nmse)

            # Save model if validation NMSE improves
            is_best_model = False
            if val_nmse < best_val_nmse:
                best_val_nmse = val_nmse
                model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
                torch.save(model.state_dict(), model_path)
                print(f"Model saved: {model_path}")
                is_best_model = True

            # Log the results
            print(f"  Train NMSE: {train_nmse:.4f}")
            print(f"  Validation NMSE: {val_nmse:.4f}")
            print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.6e}")

            # Append to CSV log
            with open(log_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch + 1, train_nmse, val_nmse, scheduler.get_last_lr()[0], is_best_model])

        # Plot losses after each epoch
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(train_nmse_losses) + 1), train_nmse_losses, label="Train NMSE")
        plt.plot(range(1, len(val_nmse_losses) + 1), val_nmse_losses, label="Validation NMSE")
        plt.xlabel("Epochs")
        plt.ylabel("NMSE")
        plt.title("Training and Validation NMSE Loss")
        plt.legend()
        plt.grid(True)
        plt.show()

    print("Training and validation complete.")
    return model
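Loss note (illustrative only, not part of train.py): nmse_loss returns a per-sample NMSE, so the training loop above sums it over the batch and divides by the running sample count to report the average. A quick sanity check on toy tensors (shapes arbitrary):

    # Illustrative check of nmse_loss on toy tensors.
    import torch
    from train import nmse_loss

    y_true = torch.ones(2, 4, 8)     # 2 samples
    y_pred = y_true * 0.9            # 10% amplitude error everywhere

    per_sample = nmse_loss(y_pred, y_true)
    print(per_sample)                        # tensor([0.0100, 0.0100]) -> ||e||^2 / ||y||^2 per sample
    print(per_sample.sum() / y_true.size(0)) # batch-average NMSE, as accumulated in train_lwm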
#%% FINE-TUNE
from torch.cuda.amp import GradScaler, autocast

# Define the ClassificationHead
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


# Define the RegressionHead
class RegressionHead(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.fc(x)


class CustomClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)


class CustomRegressionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.regressor = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        return self.regressor(x)


def custom_heads(input_dim, num_classes=None, output_dim=None, task_type="classification"):
    """
    Creates a custom head for classification or regression tasks.
    Users should modify the class implementations for further customization.

    Args:
        input_dim (int): Input dimension of the head.
        num_classes (int): Number of classes for classification tasks. Ignored for regression.
        output_dim (int): Output dimension for regression tasks. Ignored for classification.
        task_type (str): "classification" or "regression".

    Returns:
        nn.Module: Custom head for the specified task.
    """
    if task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        return CustomClassificationHead(input_dim=input_dim, num_classes=num_classes)
    elif task_type == "regression":
        return CustomRegressionHead(input_dim=input_dim, output_dim=output_dim)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
#%%
# Fine-tuning wrapper for the base model
class FineTuningWrapper(nn.Module):
    def __init__(self, model, task_head, fine_tune_layers="full"):
        super().__init__()
        self.model = model
        self.task_head = task_head

        # Freeze all layers initially
        for param in self.model.parameters():
            param.requires_grad = False

        # Handle fine-tuning layers
        if fine_tune_layers is not None:
            if fine_tune_layers == "full":
                # Unfreeze all layers if "full" is specified
                for param in self.model.parameters():
                    param.requires_grad = True
            else:
                # Get a list of all available layer names in the model
                available_layers = [name for name, _ in self.model.named_parameters()]

                # Validate that specified layers exist in the model
                for layer in fine_tune_layers:
                    if not any(layer in lname for lname in available_layers):
                        raise ValueError(
                            f"Layer '{layer}' not found in the model. "
                            f"Available layers: {available_layers}"
                        )

                # Unfreeze only the specified layers
                for name, param in self.model.named_parameters():
                    if any(layer in name for layer in fine_tune_layers):
                        param.requires_grad = True

    def forward(self, x, input_type="cls_emb"):
        if input_type == "raw":
            task_input = x.view(x.size(0), -1)
        else:
            embeddings, attn_maps = self.model(x)  # Get embeddings from the base model
            if input_type == "cls_emb":
                task_input = embeddings[:, 0, :]  # CLS token
            elif input_type in ("channel_emb", "chs_emb"):  # accept both spellings used across the scripts
                chs_emb = embeddings[:, 1:, :]
                task_input = chs_emb.view(chs_emb.size(0), -1)  # flatten channel embeddings (alternative: embeddings.mean(dim=1))

        return self.task_head(task_input), 0 if input_type == "raw" else attn_maps
#%%
# Fine-tuning function
from sklearn.metrics import f1_score
def finetune(
    base_model,
    train_loader,
    val_loader=None,
    task_type="classification",
    input_type="cls_emb",
    num_classes=None,
    output_dim=None,
    use_custom_head=False,
    fine_tune_layers=None,
    optimizer_config=None,
    criterion=None,
    epochs=10,
    device="cuda",
    task="Beam Prediction"
):
    """
    Configures and fine-tunes the base model with user-defined settings, saving results and models.
    """
    # Create results folder
    time_now = f"{time.time():.0f}"
    results_folder = f"results/{task}/{time_now}"
    os.makedirs(results_folder, exist_ok=True)
    log_file = os.path.join(results_folder, "training_log.csv")

    # Initialize the CSV log
    with open(log_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Task", "Input", "Epoch", "Train Loss", "Validation Loss", "F1-Score (Classification)", "Learning Rate", "Time"])

    # Peek at one validation batch to infer the input dimensions
    for batch in val_loader:
        input_data, targets = batch[0].to(device), batch[1].to(device)
        break

    if input_type == "cls_emb":
        n_patches = 1
        patch_size = 128
    elif input_type == "channel_emb":
        n_patches = input_data.shape[1] - 1
        patch_size = 128
    elif input_type == "raw":
        n_patches = input_data.shape[1]
        patch_size = 32
        # patch_size = 1

    if use_custom_head:
        custom_head = custom_heads(input_dim=n_patches*patch_size,
                                   num_classes=num_classes,
                                   output_dim=output_dim,
                                   task_type=task_type)

    # Handle DataParallel models
    if isinstance(base_model, nn.DataParallel):
        base_model = base_model.module

    # Set up the task-specific head
    if use_custom_head:
        task_head = custom_head
    elif task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        task_head = ClassificationHead(input_dim=n_patches*patch_size, num_classes=num_classes)
    elif task_type == "regression":
        task_head = RegressionHead(input_dim=n_patches*patch_size)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")

    # Wrap the model with the fine-tuning head
    wrapper = FineTuningWrapper(base_model, task_head, fine_tune_layers=fine_tune_layers)
    wrapper = wrapper.to(device)

    print(f'Number of head parameters: {count_parameters(wrapper)}')

    # Set default optimizer config if not provided
    if optimizer_config is None:
        optimizer_config = {"lr": 1e-4}
    # Set up the optimizer
    optimizer = torch.optim.Adam(wrapper.parameters(), **optimizer_config)
    # Set up the scheduler for learning rate decay (gamma=0.2 reduces the LR by 5x every 10 epochs)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

    # Set up the loss criterion
    if criterion is None:
        criterion = nn.CrossEntropyLoss() if task_type == "classification" else nn.MSELoss()

    scaler = GradScaler()
    train_losses, val_losses, f1_scores = [], [], []
    best_val_loss = float("inf")
    best_model_path = None

    for epoch in range(epochs):
        # Training loop
        wrapper.train()
        epoch_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}") as progress_bar:
            for batch in progress_bar:
                input_data, targets = batch[0].to(device), batch[1].to(device)
                optimizer.zero_grad()

                with autocast():
                    outputs, attn_maps = wrapper(input_data, input_type=input_type)
                    loss = criterion(outputs, targets)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({"Loss": loss.item()})

        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation loop
        if val_loader:
            wrapper.eval()
            val_loss = 0.0
            all_preds, all_targets = [], []

            with torch.no_grad():
                for batch in val_loader:
                    input_data, targets = batch[0].to(device), batch[1].to(device)
                    with autocast():
                        outputs, _ = wrapper(input_data, input_type=input_type)
                        loss = criterion(outputs, targets)

                    val_loss += loss.item()

                    if task_type == "classification":
                        preds = torch.argmax(outputs, dim=1).cpu().numpy()
                        all_preds.extend(preds)
                        all_targets.extend(targets.cpu().numpy())

            avg_val_loss = val_loss / len(val_loader)
            val_losses.append(avg_val_loss)

            time_now = f"{time.time():.0f}"
            # Save the best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = os.path.join(results_folder, f"{input_type}_epoch{epoch+1}_valLoss{avg_val_loss:.4f}_{time_now}.pth")
                torch.save(wrapper.state_dict(), best_model_path)
                print(f"Model saved at {best_model_path} with validation loss: {best_val_loss:.4f}")

            # Compute F1-score for classification tasks
            f1 = None
            if task_type == "classification":
                f1 = f1_score(all_targets, all_preds, average="macro")
                print(f"Epoch {epoch + 1}, Validation F1-Score: {f1:.4f}")
                f1_scores.append(f1)

        scheduler.step()

        # Log results
        with open(log_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([task, input_type, epoch + 1, avg_train_loss, avg_val_loss, f1 if f1 is not None else "-", scheduler.get_last_lr()[0], f"{time_now}"])

    # Plot training and validation losses
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, epochs + 1), train_losses, label="Training Loss")
    plt.plot(range(1, epochs + 1), val_losses, label="Validation Loss", linestyle="--")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.grid(True)
    # plt.savefig(os.path.join(results_folder, "loss_curve.png"))
    plt.show()

    return wrapper, best_model_path, train_losses, val_losses, f1_scores if task_type == "classification" else 0, attn_maps
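Usage note (illustrative only, not part of the uploaded scripts): a minimal sketch of wiring finetune to LWM with CLS embeddings and a frozen backbone. All shapes, sample counts, and argument values below are assumptions for illustration; in practice the tokenized channels and labels come from input_preprocess.tokenizer.

    # Hypothetical call: LoS/NLoS classification on CLS embeddings, training only the task head.
    import torch
    import lwm_model
    from train import finetune
    from utils import prepare_loaders

    tokens = torch.randn(1000, 129, 32)      # e.g. tokenized channels (CLS + 128 patches), shapes assumed
    labels = torch.randint(0, 2, (1000,))    # e.g. LoS/NLoS labels

    train_loader, val_loader, samples, target = prepare_loaders(
        tokens, labels=labels,
        input_type="cls_emb", task_type="classification",
        train_ratio=0.8, batch_size=64,
    )

    base_model = lwm_model.lwm()             # pre-trained weights would normally be loaded here
    wrapper, best_ckpt, train_losses, val_losses, f1_scores, attn_maps = finetune(
        base_model=base_model,
        train_loader=train_loader,
        val_loader=val_loader,
        task_type="classification",
        input_type="cls_emb",
        num_classes=2,                       # LoS vs. NLoS
        fine_tune_layers=None,               # keep the LWM backbone frozen, train only the head
        epochs=10,
        device="cuda",
        task="LoS/NLoS Classification",
    )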
utils.py
ADDED
@@ -0,0 +1,247 @@
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
import torch
import numpy as np
#%%
def create_dataloader(grouped_data, batch_size, shuffle):

    dataloaders = {}

    for seq_length, group in grouped_data.items():

        print(f"dataloader in progress ...\nkey: {seq_length}")

        ## Uncomment the following line if you run out of memory during pre-training
        # batch_size = batch_size // 8 if seq_length >= 5 else batch_size

        # Unpack samples for the current group
        input_ids, masked_tokens, masked_pos = zip(*group)

        # Convert to tensors
        input_ids_tensor = torch.tensor(input_ids, dtype=torch.float32)
        masked_tokens_tensor = torch.tensor(masked_tokens, dtype=torch.float32)
        masked_pos_tensor = torch.tensor(masked_pos, dtype=torch.long)

        # Create TensorDataset and DataLoader
        dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
        dataloaders[seq_length] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)

    return dataloaders
#%%
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
#%%
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap

def visualize_embeddings(embeddings, labels, method="pca", label=None):
    """
    Visualize embeddings using PCA, UMAP, or t-SNE with color-coded labels.

    Args:
        embeddings (torch.Tensor or np.ndarray): Embeddings to visualize, shape (n_samples, n_features).
        labels (torch.Tensor or np.ndarray): Class labels corresponding to embeddings, shape (n_samples,).
        method (str): Dimensionality reduction method ('pca', 'umap', or 'tsne').
        label (str): Label used in the plot title.
    """
    # Convert to numpy if input is a torch.Tensor
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    # Apply the selected dimensionality reduction method
    if method.lower() == "pca":
        reducer = PCA(n_components=2)
    elif method.lower() == "umap":
        reducer = umap.UMAP(n_components=2, n_neighbors=16, random_state=42)
    elif method.lower() == "tsne":
        reducer = TSNE(n_components=2, random_state=42, init="random")
    else:
        raise ValueError("Invalid method. Choose from 'pca', 'umap', or 'tsne'.")

    reduced_embeddings = reducer.fit_transform(embeddings)

    # Create a scatter plot with color-coding based on labels
    plt.figure(figsize=(10, 8))
    num_classes = len(np.unique(labels))
    colors = plt.cm.get_cmap("tab10", num_classes)

    for class_idx in range(num_classes):
        class_points = reduced_embeddings[labels == class_idx]
        plt.scatter(
            class_points[:, 0], class_points[:, 1],
            label=f"Class {class_idx}",
            alpha=0.6
        )

    # Customize the plot
    plt.title(f"{label} ({method.upper()})")
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.legend()
    plt.show()
#%%
def generate_gaussian_noise(data, snr_db):
    """
    Generate Gaussian noise for a given SNR, leaving the first (CLS) position noise-free.

    Args:
        data (torch.Tensor): Input data tensor of shape (n_samples, seq_len, feature_dim).
        snr_db (float): Signal-to-Noise Ratio in decibels (dB).

    Returns:
        torch.Tensor: Noise tensor with the same shape as the input data.
    """
    # Separate the input data to exclude the first channel
    a = data[:, 1:, :]  # Shape: (n_samples, seq_len-1, feature_dim)
    flat_data = a.view(a.size(0), -1)  # Flatten data to calculate power
    signal_power = torch.mean(flat_data**2, dim=1, keepdim=True)  # Shape: (n_samples, 1)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear
    noise = torch.randn_like(flat_data) * torch.sqrt(noise_power)
    noise = noise.view_as(a)
    noise = torch.cat((torch.zeros_like(data[:, :1, :]), noise), dim=1)  # Add zero noise for the first channel

    return noise
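Usage note (illustrative only, not part of utils.py): generate_gaussian_noise returns a noise tensor of the same shape as the input, scaled so the per-sample noise power equals signal_power / 10^(SNR_dB/10), with zeros on the CLS position, so it can be added directly to the tokenized channels. A small sketch with assumed shapes:

    # Hypothetical example: add 20 dB noise to a batch of tokenized channels (CLS + 128 patches).
    import torch
    from utils import generate_gaussian_noise

    data = torch.randn(16, 129, 32)              # (n_samples, seq_len, feature_dim), shapes assumed
    noise = generate_gaussian_noise(data, snr_db=20.0)
    noisy_data = data + noise

    print(noise[:, 0].abs().max())               # 0.0 -> the CLS position is left untouched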
#%%
def plot_coverage(rxs, cov_map, dpi=200, figsize=(6,4), cbar_title=None, title=False,
                  scat_sz=.5, tx_pos=None, tx_ori=None, legend=False, lims=None,
                  proj_3D=False, equal_aspect=False, tight=True, cmap='tab20'):

    plt_params = {'cmap': cmap}
    if lims:
        plt_params['vmin'], plt_params['vmax'] = lims[0], lims[1]

    n = 3 if proj_3D else 2  # number of coordinates to consider: 2 = xy | 3 = xyz

    xyz = {'x': rxs[:,0], 'y': rxs[:,1]}
    if proj_3D:
        xyz['zs'] = rxs[:,2]

    fig, ax = plt.subplots(dpi=dpi, figsize=figsize,
                           subplot_kw={'projection': '3d'} if proj_3D else {})

    im = plt.scatter(**xyz, c=cov_map, s=scat_sz, marker='s', **plt_params)

    cbar = plt.colorbar(im, label='' if not cbar_title else cbar_title)

    plt.xlabel('x (m)')
    plt.ylabel('y (m)')

    # TX position
    if tx_pos is not None:
        ax.scatter(*tx_pos[:n], marker='P', c='r', label='TX')

    # TX orientation
    if tx_ori is not None and tx_pos is not None:  # ori = [azi, el]
        # positive azimuths point left (like positive angles in a unit circle)
        # positive elevations point up
        r = 30  # ref size of pointing direction
        tx_lookat = np.copy(tx_pos)
        tx_lookat[:2] += r * np.array([np.cos(tx_ori[2]), np.sin(tx_ori[2])])  # azimuth
        tx_lookat[2] += r * np.sin(tx_ori[1])  # elevation

        line_components = [[tx_pos[i], tx_lookat[i]] for i in range(n)]
        line = {key: val for key, val in zip(['xs', 'ys', 'zs'], line_components)}
        if n == 2:
            ax.plot(line_components[0], line_components[1], c='k', alpha=.5, zorder=3)
        else:
            ax.plot(**line, c='k', alpha=.5, zorder=3)

    if title:
        ax.set_title(title)

    if legend:
        plt.legend(loc='upper center', ncols=10, framealpha=.5)

    if tight:
        s = 1
        mins, maxs = np.min(rxs, axis=0)-s, np.max(rxs, axis=0)+s
        if not proj_3D:
            plt.xlim([mins[0], maxs[0]])
            plt.ylim([mins[1], maxs[1]])
        else:
            ax.axes.set_xlim3d([mins[0], maxs[0]])
            ax.axes.set_ylim3d([mins[1], maxs[1]])
            if tx_pos is None:
                ax.axes.set_zlim3d([mins[2], maxs[2]])
            else:
                ax.axes.set_zlim3d([np.min([mins[2], tx_pos[2]]),
                                    np.max([mins[2], tx_pos[2]])])

    if equal_aspect and not proj_3D:  # disrupts 3D plots
        plt.axis('scaled')

    return fig, ax, cbar
#%%
def prepare_loaders(
    preprocessed_data,
    labels=None,
    selected_patches_idxs=None,
    input_type="raw",
    task_type="classification",
    feature_selection=False,
    train_ratio=0.8,
    batch_size=64,
    seed=42  # Default seed for reproducibility
):
    """
    Prepares datasets and data loaders for training and validation.

    Args:
        preprocessed_data (torch.Tensor): The input data, either raw or preprocessed.
        labels (torch.Tensor, optional): The labels for classification tasks.
        selected_patches_idxs (torch.Tensor, optional): Indices of selected patches for feature selection.
        input_type (str): "raw" or "processed" to specify the input data type.
        task_type (str): "classification" or "regression".
        feature_selection (bool): Whether to perform feature selection based on selected_patches_idxs.
        train_ratio (float): Proportion of data to use for training (remaining for validation).
        batch_size (int): Batch size for data loaders.
        seed (int): Random seed for reproducibility.

    Returns:
        tuple: (train_loader, val_loader, samples, target)
    """
    # Set random seed for reproducibility
    torch.manual_seed(seed)

    # Prepare samples
    if input_type == "raw":
        if feature_selection and selected_patches_idxs is not None:
            batch_indices = torch.arange(preprocessed_data.size(0)).unsqueeze(1)  # Shape: [batch_size, 1]
            samples = torch.tensor(preprocessed_data[batch_indices, selected_patches_idxs], dtype=torch.float32)
        else:
            samples = torch.tensor(preprocessed_data[:, 1:], dtype=torch.float32)  # raw_chs
    else:
        samples = torch.tensor(preprocessed_data, dtype=torch.float32)

    # Prepare dataset
    if task_type == "classification":
        if labels is None:
            raise ValueError("Labels are required for classification tasks.")
        labels = torch.tensor(labels, dtype=torch.long)
        dataset = TensorDataset(samples, labels)
        target = 0  # REVISE if needed
    elif task_type == "regression":
        target = samples[:, 1:, :].view(samples.size(0), -1)  # Reshape for regression targets
        dataset = TensorDataset(samples, target)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")

    # Set random seed for reproducibility
    generator = torch.Generator().manual_seed(seed)

    # Split dataset into training and validation
    n_samples = len(dataset)
    train_size = int(train_ratio * n_samples)
    val_size = n_samples - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
    return train_loader, val_loader, samples, target
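Usage note (illustrative only, not part of the repository scripts): a minimal sketch of calling visualize_embeddings on assumed inputs, e.g. CLS embeddings produced by lwm_inference with class labels from the downstream task.

    # Hypothetical example: project 128-dimensional embeddings of two classes to 2D with t-SNE.
    import numpy as np
    from utils import visualize_embeddings

    embeddings = np.random.randn(200, 128)        # e.g. CLS embeddings (values assumed)
    labels = np.random.randint(0, 2, size=200)    # e.g. LoS/NLoS labels (values assumed)

    visualize_embeddings(embeddings, labels, method="tsne", label="LWM CLS embeddings")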