wi-lab committed on
Commit 8920c6e · verified · 1 Parent(s): 2364aca

Upload the pre-trained model and pre-training, inference, downstream, and utility scripts

Files changed (9)
  1. .gitignore +2 -0
  2. downstream.py +146 -0
  3. inference.py +52 -0
  4. input_preprocess.py +1020 -0
  5. lwm_model.py +154 -0
  6. main.py +120 -0
  7. models/model.pth +3 -0
  8. train.py +446 -0
  9. utils.py +247 -0
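
For orientation, below is a minimal end-to-end sketch of how these files fit together, adapted from downstream.py. The scenario choice, BS index, and the use of models/model.pth as the checkpoint path are illustrative assumptions rather than a prescribed setup; generating channels requires the DeepMIMO scenario files under ./scenarios.

import torch
import lwm_model
from input_preprocess import tokenizer, scenarios_list
from inference import lwm_inference

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1) Generate and tokenize DeepMIMO channels for one scenario
scenario = [scenarios_list()[0]]
preprocessed_data, labels, raw_chs = tokenizer(scenario, bs_idxs=[3], load_data=False)

# 2) Load the pre-trained LWM backbone (assumes the checkpoint was saved without the
#    DataParallel "module." prefix, since lwm.from_pretrained loads the state dict directly)
model = lwm_model.lwm.from_pretrained("models/model.pth", device=device)

# 3) Extract CLS embeddings to feed a downstream model (see downstream.py / train.finetune)
cls_emb = lwm_inference(model, preprocessed_data, input_type="cls_emb", device=device)
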
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ __pycache__*
2
+ /images
downstream.py ADDED
@@ -0,0 +1,146 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Jan 10 11:11:58 2025
4
+
5
+ This script evaluates downstream task performance by comparing models trained
6
+ on raw channel representations versus those trained on LWM embeddings.
7
+
8
+ @author: Sadjad Alikhani
9
+ """
10
+ #%% IMPORT PACKAGES & MODULES
11
+ from input_preprocess import tokenizer, scenarios_list
12
+ from inference import lwm_inference
13
+ from utils import prepare_loaders
14
+ from train import finetune
15
+ import lwm_model
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import warnings
21
+ warnings.filterwarnings("ignore", category=UserWarning)
22
+ #%% DOWNSTREAM DATA GENERATION
23
+ n_beams = 16
24
+ task = ['Beam Prediction', 'LoS/NLoS Classification'][1]
25
+ task_type = ["classification", "regression"][0]
26
+ visualization_method = ["pca", "umap", "tsne"][2]
27
+ input_types = ["cls_emb", "channel_emb", "raw"]
28
+ train_ratios = [.001, .01, .05, .1, .25, .5, .8]
29
+ fine_tuning_status = [None, ["layers.8", "layers.9", "layers.10", "layers.11"], "full"]
30
+ selected_scenario_names = [scenarios_list()[18]]
31
+ preprocessed_data, labels, raw_chs = tokenizer(
32
+ selected_scenario_names,
33
+ bs_idxs=[3],
34
+ load_data=False,
35
+ task=task,
36
+ n_beams=n_beams)
37
+ #%% LOAD THE MODEL
38
+ gpu_ids = [0]
39
+ device = torch.device("cuda:0")
40
+ model = lwm_model.lwm().to(device)
41
+
42
+ model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
43
+ state_dict = torch.load(f"models/{model_name}", map_location=device)
44
+ new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
45
+ model.load_state_dict(new_state_dict)
46
+
47
+ model = nn.DataParallel(model, gpu_ids)
48
+ print(f"Model loaded successfully on GPU {device.index}")
49
+ #%% 2D EMBEDDING SPACE VISUALIZATION BEFORE FINE-TUNING
50
+ chs = lwm_inference(
51
+ model,
52
+ preprocessed_data,
53
+ input_type="cls_emb",
54
+ device=device,
55
+ batch_size=64,
56
+ visualization=False,
57
+ labels=labels,
58
+ visualization_method=visualization_method)
59
+ #%% FINE-TUNE
60
+ results = np.zeros((len(fine_tuning_status), len(input_types), len(train_ratios)))
61
+ for fine_tuning_stat_idx, fine_tuning_stat in enumerate(fine_tuning_status):
62
+ for input_type_idx, input_type in enumerate(input_types):
63
+
64
+ if input_type == "raw" and fine_tuning_stat is not None:
65
+ continue
66
+
67
+ selected_patches_idxs = None
68
+ for train_ratio_idx, train_ratio in enumerate(train_ratios):
69
+
70
+ print(f"\nfine-tuning status: {fine_tuning_stat}")
71
+ print(f"input type: {input_type}")
72
+ print(f"train ratio: {train_ratio}\n")
73
+
74
+ # PREPARE LOADERS
75
+ train_loader, val_loader, samples, target = prepare_loaders(
76
+ preprocessed_data=preprocessed_data,
77
+ labels=labels,
78
+ selected_patches_idxs=selected_patches_idxs,
79
+ input_type=input_type,
80
+ task_type=task_type,
81
+ train_ratio=train_ratio,
82
+ batch_size=128,
83
+ seed=42
84
+ )
85
+
86
+ # FINE-TUNE LWM
87
+ fine_tuned_model, best_model_path, train_losses, val_losses, f1_scores, attn_maps_ft = finetune(
88
+ base_model=model,
89
+ train_loader=train_loader,
90
+ val_loader=val_loader,
91
+ task_type=task_type,
92
+ input_type=input_type,
93
+ num_classes=n_beams if task=='Beam Prediction' else 2 if task=='LoS/NLoS Classification' else None,
94
+ output_dim=target.shape[-1] if task_type =='regression' else None,
95
+ use_custom_head=True,
96
+ fine_tune_layers=fine_tuning_stat,
97
+ optimizer_config={"lr": 1e-3},
98
+ epochs=15,
99
+ device=device,
100
+ task=task
101
+ )
102
+
103
+ results[fine_tuning_stat_idx][input_type_idx][train_ratio_idx] = f1_scores[-1]
104
+
105
+ markers = ['o', 's', 'D']
106
+ legend_labels = ['CLS Emb', 'CHS Emb', 'Raw']  # renamed so the dataset labels tensor is not overwritten
107
+ fine_tuning_status_labels = ['No FT', 'Partial FT', 'Full FT']
108
+ line_styles = ['-', '--', ':']
109
+ colors = plt.cm.viridis(np.linspace(0, 0.8, len(legend_labels)))
110
+ plt.figure(figsize=(12, 8), dpi=500)
111
+ for ft_idx, (ft_status_label, line_style) in enumerate(zip(fine_tuning_status_labels, line_styles)):
112
+ for idx, (marker, label, color) in enumerate(zip(markers, legend_labels, colors)):
113
+ # For raw channels, only plot the no-fine-tuning case
114
+ if label == "Raw" and ft_status_label != "No FT":
115
+ continue
116
+ # Use a plain "Raw Channels" legend entry for raw inputs (no fine-tuning suffix)
117
+ plot_label = "Raw Channels" if label == "Raw" else label
118
+ plt.plot(
119
+ train_ratios,
120
+ results[ft_idx, idx],
121
+ marker=marker,
122
+ linestyle=line_style,
123
+ label=f"{plot_label} ({ft_status_label})" if label != "Raw" else plot_label,
124
+ color=color,
125
+ linewidth=3,
126
+ markersize=9
127
+ )
128
+ plt.xscale('log')
129
+ plt.xlabel("Train Ratio", fontsize=20)
130
+ plt.ylabel("F1-Score", fontsize=20)
131
+ plt.legend(fontsize=17, loc="best")
132
+ plt.grid(True, linestyle="--", alpha=0.7)
133
+ plt.xticks(fontsize=17)
134
+ plt.yticks(fontsize=17)
135
+ plt.tight_layout()
136
+ plt.show()
137
+ #%% 2D EMBEDDING SPACE VISUALIZATION AFTER FINE-TUNING
138
+ chs = lwm_inference(
139
+ fine_tuned_model.model,
140
+ preprocessed_data,
141
+ input_type="cls_emb",
142
+ device=device,
143
+ batch_size=64,
144
+ visualization=False,
145
+ labels=labels,
146
+ visualization_method=visualization_method)
inference.py ADDED
@@ -0,0 +1,52 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Sep 15 18:27:17 2024
4
+
5
+ This script performs LWM inference on raw channel representations.
6
+
7
+ @author: Sadjad Alikhani
8
+ """
9
+ import torch
10
+ from torch.utils.data import DataLoader, TensorDataset
11
+ from utils import visualize_embeddings
12
+ from tqdm import tqdm
13
+ import warnings
14
+ warnings.filterwarnings('ignore')
15
+ #%%
16
+ def lwm_inference(model, data, input_type="cls_emb", device="cpu", batch_size=64, visualization=False, labels=None, visualization_method="t-sne"):
17
+
18
+ if input_type == "raw":
19
+ output_total = data
20
+ else:
21
+ dataset = TensorDataset(data)
22
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
23
+
24
+ embeddings = []
25
+ model.eval()
26
+ with torch.no_grad():
27
+ with tqdm(dataloader, desc="Inference", unit="batch") as t:
28
+ for batch in t:
29
+
30
+ input_ids = batch[0].to(device)
31
+ output = model(input_ids)[0]
32
+
33
+ if input_type == "cls_emb":
34
+ batch_embeddings = output[:, 0, :]
35
+ embeddings.append(batch_embeddings)
36
+ elif input_type == "channel_emb":
37
+ batch_embeddings = output[:, 1:, :]
38
+ embeddings.append(batch_embeddings)
39
+
40
+ output_total = torch.cat(embeddings, dim=0).float()
41
+
42
+ if visualization:
43
+ visualize_embeddings(output_total.view(output_total.size(0), -1),
44
+ labels,
45
+ method=visualization_method,
46
+ label="Embedding Space")
47
+ visualize_embeddings(data.view(data.size(0), -1),
48
+ labels,
49
+ method=visualization_method,
50
+ label="Original Space")
51
+
52
+ return output_total
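
The return shape of lwm_inference depends on input_type (d_model = 128 per lwm_model.py): "cls_emb" yields one 128-dimensional vector per sample, "channel_emb" one 128-dimensional vector per channel patch, and "raw" returns the input unchanged. A short usage sketch, assuming model and preprocessed_data were prepared as in downstream.py:

cls_emb = lwm_inference(model, preprocessed_data, input_type="cls_emb", device=device)      # (N, 128)
ch_emb  = lwm_inference(model, preprocessed_data, input_type="channel_emb", device=device)  # (N, n_patches, 128)
raw     = lwm_inference(model, preprocessed_data, input_type="raw", device=device)          # identical to the input
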
input_preprocess.py ADDED
@@ -0,0 +1,1020 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 16:13:29 2024
4
+
5
+ This script generates preprocessed data from wireless communication scenarios,
6
+ including channel generation, patch generation, masking, and preparing raw
7
+ channels for the Transformer-based LWM model.
8
+
9
+ @author: Sadjad Alikhani
10
+ """
11
+ import numpy as np
12
+ import os
13
+ from tqdm import tqdm
14
+ import time
15
+ import pickle
16
+ import DeepMIMOv3
17
+ import torch
18
+ from collections import defaultdict
19
+ from utils import generate_gaussian_noise, plot_coverage
20
+ #%% Scenarios List
21
+ def scenarios_list():
22
+ scen_list = np.array([
23
+ 'city_0_newyork',
24
+ 'city_1_losangeles',
25
+ 'city_2_chicago',
26
+ 'city_3_houston',
27
+ 'city_4_phoenix',
28
+ 'city_5_philadelphia',
29
+ 'city_6_miami',
30
+ 'city_7_sandiego',
31
+ 'city_8_dallas',
32
+ 'city_9_sanfrancisco',
33
+ 'city_10_austin',
34
+ 'city_11_santaclara',
35
+ 'city_12_fortworth',
36
+ 'city_13_columbus',
37
+ 'city_14_charlotte',
38
+ 'city_15_indianapolis',
39
+ 'city_16_sanfrancisco',
40
+ 'city_17_seattle',
41
+ 'city_18_denver',
42
+ 'city_19_oklahoma',
43
+ 'asu_campus1_v1',
44
+ 'asu_campus1_v2',
45
+ 'asu_campus1_v3',
46
+ 'asu_campus1_v4',
47
+ 'asu_campus1_v5',
48
+ 'asu_campus1_v6',
49
+ 'asu_campus1_v7',
50
+ 'asu_campus1_v8',
51
+ 'asu_campus1_v9',
52
+ 'asu_campus1_v10',
53
+ 'asu_campus1_v11',
54
+ 'asu_campus1_v12',
55
+ 'asu_campus1_v13',
56
+ 'asu_campus1_v14',
57
+ 'asu_campus1_v15',
58
+ 'asu_campus1_v16',
59
+ 'asu_campus1_v17',
60
+ 'asu_campus1_v18',
61
+ 'asu_campus1_v19',
62
+ 'asu_campus1_v20',
63
+ 'Boston5G_3p5_v1',
64
+ 'Boston5G_3p5_v2',
65
+ 'Boston5G_3p5_v3',
66
+ 'Boston5G_3p5_v4',
67
+ 'Boston5G_3p5_v5',
68
+ 'Boston5G_3p5_v6',
69
+ 'Boston5G_3p5_v7',
70
+ 'Boston5G_3p5_v8',
71
+ 'Boston5G_3p5_v9',
72
+ 'Boston5G_3p5_v10',
73
+ 'Boston5G_3p5_v11',
74
+ 'Boston5G_3p5_v12',
75
+ 'Boston5G_3p5_v13',
76
+ 'Boston5G_3p5_v14',
77
+ 'Boston5G_3p5_v15',
78
+ 'Boston5G_3p5_v16',
79
+ 'Boston5G_3p5_v17',
80
+ 'Boston5G_3p5_v18',
81
+ 'Boston5G_3p5_v19',
82
+ 'Boston5G_3p5_v20',
83
+ 'O1_3p5_v1',
84
+ 'O1_3p5_v2',
85
+ 'O1_3p5_v3',
86
+ 'O1_3p5_v4',
87
+ 'O1_3p5_v5',
88
+ 'O1_3p5_v6',
89
+ 'O1_3p5_v7',
90
+ 'O1_3p5_v8',
91
+ 'O1_3p5_v9',
92
+ 'O1_3p5_v10',
93
+ 'O1_3p5_v11',
94
+ 'O1_3p5_v12',
95
+ 'O1_3p5_v13',
96
+ 'O1_3p5_v14',
97
+ 'O1_3p5_v15',
98
+ 'O1_3p5_v16',
99
+ 'O1_3p5_v17',
100
+ 'O1_3p5_v18',
101
+ 'O1_3p5_v19',
102
+ 'O1_3p5_v20',
103
+ 'asu_campus1',
104
+ 'O1_3p5',
105
+ 'Boston5G_3p5',
106
+ 'city_0_newyork_v16x64',
107
+ 'city_1_losangeles_v16x64',
108
+ 'city_2_chicago_v16x64',
109
+ 'city_3_houston_v16x64',
110
+ 'city_4_phoenix_v16x64',
111
+ 'city_5_philadelphia_v16x64',
112
+ 'city_6_miami_v16x64',
113
+ 'city_7_sandiego_v16x64',
114
+ 'city_8_dallas_v16x64',
115
+ 'city_9_sanfrancisco_v16x64'
116
+ ])
117
+ return scen_list
118
+ #%% Token Generation
119
+ def patch_gen(N_ROWS=4, N_COLUMNS=4, selected_scenario_names=None,
120
+ manual_data=None, bs_idxs=[1,2,3], load_data=False,
121
+ save_dir="data", task="LoS/NLoS Classification",
122
+ n_beams=64, o1_bs_idx=[4]):
123
+
124
+ os.makedirs(save_dir, exist_ok=True)
125
+
126
+ if manual_data is not None:
127
+ patches = patch_maker(np.expand_dims(np.array(manual_data), axis=1))
128
+ else:
129
+ deepmimo_data = []
130
+ for scenario_name in selected_scenario_names:
131
+ if "O1" in scenario_name: # make an exception for bs idxs of the o1 scenario
132
+ if o1_bs_idx is None:
133
+ bs_idxs = [4, 15]
134
+ else:
135
+ bs_idxs = o1_bs_idx
136
+ for bs_idx in bs_idxs:
137
+ if has_version_suffix(scenario_name) and bs_idx in [2,3]:
138
+ continue
139
+ if not load_data:
140
+ print(f"\nGenerating data for scenario: {scenario_name}, BS #{bs_idx}")
141
+ data, n_ant_bs, n_subcarriers = DeepMIMO_data_gen(scenario_name, bs_idx)
142
+ file_name = f"{save_dir}/{scenario_name}_ant{n_ant_bs}_sub{n_subcarriers}_bs{bs_idx}.npy"
143
+ np.save(file_name, data)
144
+ print(f"Data saved to {file_name}")
145
+ deepmimo_data.append(data)
146
+ else:
147
+ n_ant_bs, n_subcarriers = parametersv2(scenario_name, bs_idx)
148
+ print(f"\nLoading data for scenario: {scenario_name}, BS #{bs_idx}")
149
+ file_name = f"{save_dir}/{scenario_name}_ant{n_ant_bs}_sub{n_subcarriers}_bs{bs_idx}.npy"
150
+ data = np.load(file_name, allow_pickle=True).item()
151
+ print(f"Data loaded from {file_name}")
152
+ deepmimo_data.append(data)
153
+
154
+ cleaned_deepmimo_data = [deepmimo_data_cleaning(deepmimo_data[scenario_idx]) for scenario_idx in range(len(deepmimo_data))] #n_scenarios*n_bs_idxs
155
+ patches = [patch_maker(cleaned_deepmimo_data[scenario_idx], N_ROWS, N_COLUMNS) for scenario_idx in range(len(deepmimo_data))]
156
+ raw_chs = torch.tensor(cleaned_deepmimo_data[0]).squeeze(1)
157
+ raw_chs = raw_chs.view(raw_chs.size(0), -1)
158
+ raw_chs = torch.hstack((raw_chs.real, raw_chs.imag))
159
+
160
+ if task:
161
+ labels = [label_gen(task, deepmimo_data[scenario_idx], selected_scenario_names[scenario_idx], n_beams=n_beams) for scenario_idx in range(len(deepmimo_data))]
162
+ return patches, torch.tensor(labels[0]), raw_chs.view(raw_chs.size(0), -1)
163
+ else:
164
+ return patches, raw_chs.view(raw_chs.size(0), -1)
165
+ #%%
166
+ def tokenizer(selected_scenario_names,
167
+ bs_idxs=[1,2,3],
168
+ load_data=False,
169
+ task="LoS/NLoS Classification",
170
+ n_beams=64,
171
+ MAX_LEN=513,
172
+ masking_percent=.40,
173
+ mask=False,
174
+ seed=42,
175
+ snr=None):
176
+
177
+ patches, labels, raw_chs = patch_gen(
178
+ selected_scenario_names=selected_scenario_names,
179
+ bs_idxs=bs_idxs,
180
+ load_data=load_data,
181
+ task=task,
182
+ n_beams=n_beams
183
+ )
184
+
185
+ patches = [patch for patch_list in patches for patch in patch_list]
186
+ print("Total number of samples:", len(patches))
187
+
188
+ grouped_data = defaultdict(list) # Group samples by sequence length
189
+ grouped_data_2 = []
190
+
191
+ for user_idx in tqdm(range(len(patches)), desc="Processing items"):
192
+ patch_size = patches[user_idx].shape[1]
193
+ n_patches = patches[user_idx].shape[0]
194
+ n_masks_half = int(masking_percent * n_patches)
195
+
196
+ word2id = {
197
+ '[CLS]': 0.2 * np.ones((patch_size)),
198
+ '[MASK]': 0.1 * np.ones((patch_size))
199
+ }
200
+
201
+ sample = make_sample(
202
+ user_idx, patches, word2id, n_patches, n_masks_half, patch_size, MAX_LEN, mask=mask, seed=seed
203
+ )
204
+
205
+ if mask:
206
+ seq_length = len(sample[0])
207
+ grouped_data[seq_length].append(sample)
208
+ else:
209
+ grouped_data_2.append(sample)
210
+
211
+ if mask:
212
+ # Normalize keys to 0, 1, 2, ...
213
+ normalized_grouped_data = {i: grouped_data[key] for i, key in enumerate(sorted(grouped_data.keys()))}
214
+ else:
215
+ normalized_grouped_data = torch.stack(grouped_data_2, dim=0)
216
+ # normalized_grouped_data = grouped_data_2
217
+ if snr is not None:
218
+ normalized_grouped_data += generate_gaussian_noise(normalized_grouped_data, snr)
219
+ # normalized_grouped_data = {i: grouped_data[key] for i, key in enumerate(sorted(grouped_data.keys()))}
220
+
221
+ return normalized_grouped_data, labels, raw_chs
222
+ #%% REMOVE ZERO CHANNELS AND SCALE
223
+ def deepmimo_data_cleaning(deepmimo_data):
224
+ idxs = np.where(deepmimo_data['user']['LoS'] != -1)[0]
225
+ cleaned_deepmimo_data = deepmimo_data['user']['channel'][idxs]
226
+ return np.array(cleaned_deepmimo_data) * 1e6
227
+ #%%
228
+ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, MAX_LEN, mask=True, seed=None):
229
+
230
+ if seed is not None:
231
+ np.random.seed(seed)
232
+
233
+ # Step 1: Retrieve tokens and prepend [CLS]
234
+ tokens = patch[user_idx]
235
+ input_ids = np.vstack((word2id['[CLS]'], tokens))
236
+
237
+ # Step 2: Mask real and imaginary patches
238
+ tokens_size = int(n_patches) # int(n_patches / 2)
239
+ masked_pos = np.random.choice(range(1, tokens_size), size=n_masks, replace=False)
240
+
241
+ masked_tokens = []
242
+ for pos in masked_pos:
243
+ original_masked_tokens = input_ids[pos].copy()
244
+ masked_tokens.append(original_masked_tokens)
245
+ if mask:
246
+ rnd_num = np.random.rand()
247
+ if rnd_num < 0.1:
248
+ input_ids[pos] = np.random.rand(patch_size) # Replace with random values
249
+ elif rnd_num < 0.9:
250
+ input_ids[pos] = word2id['[MASK]'] # Replace with [MASK]
251
+
252
+ if not mask:
253
+ return torch.tensor(input_ids)
254
+ else:
255
+ return [input_ids, masked_tokens, masked_pos]
256
+ #%% Patch Generation
257
+ def patch_maker(original_ch, patch_rows, patch_cols):
258
+ # Step 1: Remove the singleton channel dimension
259
+ n_samples, _, n_rows, n_cols = original_ch.shape # Unpack shape
260
+ original_ch = original_ch[:, 0] # Remove the singleton dimension
261
+
262
+ # Step 2: Split into real and imaginary parts and interleave them
263
+ flat_real = original_ch.real
264
+ flat_imag = original_ch.imag
265
+
266
+ # Interleave real and imaginary parts along the last axis
267
+ interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
268
+ interleaved[:, :, 0::2] = flat_real
269
+ interleaved[:, :, 1::2] = flat_imag
270
+
271
+ # Step 3: Compute the number of patches along rows and columns
272
+ n_patches_rows = int(np.ceil(n_rows / patch_rows))
273
+ n_patches_cols = int(np.ceil(n_cols / patch_cols))
274
+
275
+ # Step 4: Pad the matrix if necessary to make it divisible by patch size
276
+ padded_rows = n_patches_rows * patch_rows - n_rows
277
+ padded_cols = n_patches_cols * patch_cols - n_cols
278
+ if padded_rows > 0 or padded_cols > 0:
279
+ interleaved = np.pad(
280
+ interleaved,
281
+ ((0, 0), (0, padded_rows), (0, padded_cols * 2)), # Double padding for interleaved axis
282
+ mode='constant',
283
+ constant_values=0,
284
+ )
285
+
286
+ # Step 5: Create patches by dividing into blocks
287
+ n_samples, padded_rows, padded_cols = interleaved.shape
288
+ padded_cols //= 2 # Adjust for interleaving (real and imaginary parts count as one)
289
+ patches = []
290
+
291
+ for i in range(0, padded_rows, patch_rows):
292
+ for j in range(0, padded_cols, patch_cols):
293
+ patch = interleaved[:, i:i + patch_rows, j * 2:(j + patch_cols) * 2]
294
+ patches.append(patch.reshape(n_samples, -1)) # Flatten each patch
295
+
296
+ # Step 6: Stack patches to form the final array
297
+ patches = np.stack(patches, axis=1) # Shape: (num_samples, n_patches, patch_rows * patch_cols * 2)
298
+
299
+ return patches
300
+ #%% Data Generation for Scenario Areas
301
+ def DeepMIMO_data_gen(scenario, bs_idx):
302
+ import DeepMIMOv3
303
+ parameters, row_column_users = get_parameters(scenario, bs_idx)
304
+ deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)
305
+
306
+ if "O1" in scenario:
307
+ hops = [2, 2]
308
+ else:
309
+ hops = [1, 1]
310
+
311
+ uniform_idxs = uniform_sampling(deepMIMO_dataset, hops, len(parameters['user_rows']),
312
+ users_per_row=row_column_users[scenario]['n_per_row'])
313
+ data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]
314
+
315
+ n_ant_bs = parameters['bs_antenna']['shape'][0]
316
+ n_subcarriers = parameters['OFDM']['subcarriers']
317
+
318
+ return data, n_ant_bs, n_subcarriers
319
+ #%%
320
+ def parametersv2(scenario, bs_idx):
321
+ parameters, _ = get_parameters(scenario, bs_idx)
322
+ n_ant_bs = parameters['bs_antenna']['shape'][0]
323
+ n_subcarriers = parameters['OFDM']['subcarriers']
324
+ return n_ant_bs, n_subcarriers
325
+ #%%%
326
+ def get_parameters(scenario, bs_idx=1):
327
+
328
+ n_ant_ue = 1
329
+ scs = 30e3
330
+
331
+ row_column_users = scenario_prop()
332
+
333
+ parameters = DeepMIMOv3.default_params()
334
+ parameters['dataset_folder'] = './scenarios'
335
+ parameters['scenario'] = scenario.split("_v")[0]
336
+
337
+ n_ant_bs = row_column_users[scenario]['n_ant_bs']
338
+ n_subcarriers = row_column_users[scenario]['n_subcarriers']
339
+ parameters['active_BS'] = np.array([bs_idx])
340
+
341
+ if isinstance(row_column_users[scenario]['n_rows'], int):
342
+ parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
343
+ else:
344
+ parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
345
+ row_column_users[scenario]['n_rows'][1])
346
+
347
+ parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1]) # Horizontal, Vertical
348
+ parameters['bs_antenna']['rotation'] = np.array([0,0,-135]) # (x,y,z)
349
+ parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
350
+ parameters['enable_BS2BS'] = False
351
+ parameters['OFDM']['subcarriers'] = n_subcarriers
352
+ parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)
353
+
354
+ parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9
355
+ parameters['num_paths'] = 20
356
+
357
+ return parameters, row_column_users
358
+ #%% Sampling and Data Selection
359
+ def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
360
+ cols = np.arange(users_per_row, step=sampling_div[0])
361
+ rows = np.arange(n_rows, step=sampling_div[1])
362
+ uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])
363
+ return uniform_idxs
364
+
365
+ def select_by_idx(dataset, idxs):
366
+ dataset_t = [] # Trimmed dataset
367
+ for bs_idx in range(len(dataset)):
368
+ dataset_t.append({})
369
+ for key in dataset[bs_idx].keys():
370
+ dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
371
+ dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}
372
+ return dataset_t
373
+ #%%
374
+ def inverse_patch_maker(patches, original_shape, patch_rows, patch_cols):
375
+ """
376
+ Reconstructs the original channel matrix from patches.
377
+
378
+ Args:
379
+ patches (numpy array): Patches of shape (num_samples, n_patches, patch_rows * patch_cols * 2).
380
+ original_shape (tuple): Original shape of the channel matrix (num_samples, 1, n_rows, n_cols).
381
+ patch_rows (int): Number of rows in each patch.
382
+ patch_cols (int): Number of columns in each patch.
383
+
384
+ Returns:
385
+ numpy array: Reconstructed complex-valued channel matrix of shape (num_samples, 1, n_rows, n_cols).
386
+ """
387
+ n_samples, n_patches, patch_size = patches.shape
388
+ _, _, n_rows, n_cols = original_shape
389
+
390
+ # Ensure patch dimensions match
391
+ assert patch_rows * patch_cols * 2 == patch_size, "Patch size mismatch with provided dimensions."
392
+
393
+ # Compute the number of patches along rows and columns
394
+ n_patches_rows = int(np.ceil(n_rows / patch_rows))
395
+ n_patches_cols = int(np.ceil(n_cols / patch_cols))
396
+
397
+ # Reassemble interleaved array from patches
398
+ interleaved = np.zeros((n_samples, n_patches_rows * patch_rows, n_patches_cols * patch_cols * 2), dtype=np.float32)
399
+ patch_idx = 0
400
+
401
+ for i in range(n_patches_rows):
402
+ for j in range(n_patches_cols):
403
+ patch = patches[:, patch_idx, :].reshape(n_samples, patch_rows, patch_cols * 2)
404
+ interleaved[:, i * patch_rows:(i + 1) * patch_rows, j * patch_cols * 2:(j + 1) * patch_cols * 2] = patch
405
+ patch_idx += 1
406
+
407
+ # Remove padding if necessary
408
+ interleaved = interleaved[:, :n_rows, :n_cols * 2]
409
+
410
+ # Separate real and imaginary parts
411
+ flat_real = interleaved[:, :, 0::2]
412
+ flat_imag = interleaved[:, :, 1::2]
413
+
414
+ # Reconstruct the complex-valued original channel
415
+ reconstructed = flat_real + 1j * flat_imag
416
+
417
+ # Add the singleton channel dimension back
418
+ reconstructed = reconstructed[:, np.newaxis, :, :] # Shape: (num_samples, 1, n_rows, n_cols)
419
+
420
+ return reconstructed
421
+ #%%
422
+ def label_gen(task, data, scenario, n_beams=64):
423
+
424
+ idxs = np.where(data['user']['LoS'] != -1)[0]
425
+
426
+ if task == 'LoS/NLoS Classification':
427
+ label = data['user']['LoS'][idxs]
428
+
429
+ losChs = np.where(data['user']['LoS'] == -1, np.nan, data['user']['LoS'])
430
+ plot_coverage(data['user']['location'], losChs, cbar_title='LoS status')
431
+
432
+ elif task == 'Beam Prediction':
433
+ parameters, row_column_users = get_parameters(scenario, bs_idx=1)
434
+ n_users = len(data['user']['channel'])
435
+ n_subbands = 1
436
+ fov = 180
437
+
438
+ # Setup Beamformers
439
+ beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)
440
+
441
+ F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
442
+ phi=azi*np.pi/180,
443
+ kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
444
+ for azi in beam_angles])
445
+
446
+ full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
447
+ for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
448
+ if data['user']['LoS'][ue_idx] == -1:
449
+ full_dbm[:,:,ue_idx] = np.nan
450
+ else:
451
+ chs = F1 @ data['user']['channel'][ue_idx]
452
+ full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
453
+ full_dbm[:,:,ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)
454
+
455
+ best_beams = np.argmax(np.mean(full_dbm,axis=1), axis=0)
456
+ best_beams = best_beams.astype(float)
457
+ best_beams[np.isnan(full_dbm[0,0,:])] = np.nan
458
+ # max_bf_pwr = np.max(np.mean(full_dbm,axis=1), axis=0)
459
+
460
+ plot_coverage(data['user']['location'], best_beams, tx_pos=data['location'],
461
+ tx_ori=parameters['bs_antenna']['rotation']*np.pi/180,
462
+ cbar_title='Best beam index')
463
+
464
+ label = best_beams[idxs]
465
+
466
+ return label.astype(int)
467
+ #%%
468
+ def steering_vec(array, phi=0, theta=0, kd=np.pi):
469
+ idxs = DeepMIMOv3.ant_indices(array)
470
+ resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
471
+ return resp / np.linalg.norm(resp)
472
+ #%%
473
+ import re
474
+ def has_version_suffix(s):
475
+ pattern = r"_v([1-9]|1[0-9]|20)$"
476
+ return bool(re.search(pattern, s))
477
+ #%%
478
+ def scenario_prop():
479
+ row_column_users = {
480
+ 'city_0_newyork': {
481
+ 'n_rows': 109,
482
+ 'n_per_row': 291,
483
+ 'n_ant_bs': 8,
484
+ 'n_subcarriers': 32
485
+ },
486
+ 'city_1_losangeles': {
487
+ 'n_rows': 142,
488
+ 'n_per_row': 201,
489
+ 'n_ant_bs': 8,
490
+ 'n_subcarriers': 64
491
+ },
492
+ 'city_2_chicago': {
493
+ 'n_rows': 139,
494
+ 'n_per_row': 200,
495
+ 'n_ant_bs': 8,
496
+ 'n_subcarriers': 128
497
+ },
498
+ 'city_3_houston': {
499
+ 'n_rows': 154,
500
+ 'n_per_row': 202,
501
+ 'n_ant_bs': 8,
502
+ 'n_subcarriers': 256
503
+ },
504
+ 'city_4_phoenix': {
505
+ 'n_rows': 198,
506
+ 'n_per_row': 214,
507
+ 'n_ant_bs': 8,
508
+ 'n_subcarriers': 512
509
+ },
510
+ 'city_5_philadelphia': {
511
+ 'n_rows': 239,
512
+ 'n_per_row': 164,
513
+ 'n_ant_bs': 8,
514
+ 'n_subcarriers': 1024
515
+ },
516
+ 'city_6_miami': {
517
+ 'n_rows': 199,
518
+ 'n_per_row': 216 ,
519
+ 'n_ant_bs': 16,
520
+ 'n_subcarriers': 32
521
+ },
522
+ 'city_7_sandiego': {
523
+ 'n_rows': 207,
524
+ 'n_per_row': 176,
525
+ 'n_ant_bs': 16,
526
+ 'n_subcarriers': 64
527
+ },
528
+ 'city_8_dallas': {
529
+ 'n_rows': 207,
530
+ 'n_per_row': 190,
531
+ 'n_ant_bs': 16,
532
+ 'n_subcarriers': 128
533
+ },
534
+ 'city_9_sanfrancisco': {
535
+ 'n_rows': 196,
536
+ 'n_per_row': 206,
537
+ 'n_ant_bs': 16,
538
+ 'n_subcarriers': 256
539
+ },
540
+ 'city_10_austin': {
541
+ 'n_rows': 255,
542
+ 'n_per_row': 137,
543
+ 'n_ant_bs': 16,
544
+ 'n_subcarriers': 512
545
+ },
546
+ 'city_11_santaclara': {
547
+ 'n_rows': 117,
548
+ 'n_per_row': 285,
549
+ 'n_ant_bs': 32,
550
+ 'n_subcarriers': 32
551
+ },
552
+ 'city_12_fortworth': {
553
+ 'n_rows': 214,
554
+ 'n_per_row': 179,
555
+ 'n_ant_bs': 32,
556
+ 'n_subcarriers': 64
557
+ },
558
+ 'city_13_columbus': {
559
+ 'n_rows': 178,
560
+ 'n_per_row': 240,
561
+ 'n_ant_bs': 32,
562
+ 'n_subcarriers': 128
563
+ },
564
+ 'city_14_charlotte': {
565
+ 'n_rows': 216,
566
+ 'n_per_row': 177,
567
+ 'n_ant_bs': 32,
568
+ 'n_subcarriers': 256
569
+ },
570
+ 'city_15_indianapolis': {
571
+ 'n_rows': 200,
572
+ 'n_per_row': 196,
573
+ 'n_ant_bs': 64,
574
+ 'n_subcarriers': 32
575
+ },
576
+ 'city_16_sanfrancisco': {
577
+ 'n_rows': 201,
578
+ 'n_per_row': 208,
579
+ 'n_ant_bs': 64,
580
+ 'n_subcarriers': 64
581
+ },
582
+ 'city_17_seattle': {
583
+ 'n_rows': 185,
584
+ 'n_per_row': 205,
585
+ 'n_ant_bs': 64,
586
+ 'n_subcarriers': 128
587
+ },
588
+ 'city_18_denver': {
589
+ 'n_rows': 212,
590
+ 'n_per_row': 204,
591
+ 'n_ant_bs': 128,
592
+ 'n_subcarriers': 32
593
+ },
594
+ 'city_19_oklahoma': {
595
+ 'n_rows': 204,
596
+ 'n_per_row': 188,
597
+ 'n_ant_bs': 128,
598
+ 'n_subcarriers': 64
599
+ },
600
+ 'asu_campus1_v1': {
601
+ 'n_rows': [0, 1*int(321/20)],
602
+ 'n_per_row': 411,
603
+ 'n_ant_bs': 8,
604
+ 'n_subcarriers': 32
605
+ },
606
+ 'asu_campus1_v2': {
607
+ 'n_rows': [1*int(321/20), 2*int(321/20)],
608
+ 'n_per_row': 411,
609
+ 'n_ant_bs': 8,
610
+ 'n_subcarriers': 64
611
+ },
612
+ 'asu_campus1_v3': {
613
+ 'n_rows': [2*int(321/20), 3*int(321/20)],
614
+ 'n_per_row': 411,
615
+ 'n_ant_bs': 8,
616
+ 'n_subcarriers': 128
617
+ },
618
+ 'asu_campus1_v4': {
619
+ 'n_rows': [3*int(321/20), 4*int(321/20)],
620
+ 'n_per_row': 411,
621
+ 'n_ant_bs': 8,
622
+ 'n_subcarriers': 256
623
+ },
624
+ 'asu_campus1_v5': {
625
+ 'n_rows': [4*int(321/20), 5*int(321/20)],
626
+ 'n_per_row': 411,
627
+ 'n_ant_bs': 8,
628
+ 'n_subcarriers': 512
629
+ },
630
+ 'asu_campus1_v6': {
631
+ 'n_rows': [5*int(321/20), 6*int(321/20)],
632
+ 'n_per_row': 411,
633
+ 'n_ant_bs': 8,
634
+ 'n_subcarriers': 1024
635
+ },
636
+ 'asu_campus1_v7': {
637
+ 'n_rows': [6*int(321/20), 7*int(321/20)],
638
+ 'n_per_row': 411,
639
+ 'n_ant_bs': 16,
640
+ 'n_subcarriers': 32
641
+ },
642
+ 'asu_campus1_v8': {
643
+ 'n_rows': [7*int(321/20), 8*int(321/20)],
644
+ 'n_per_row': 411,
645
+ 'n_ant_bs':16,
646
+ 'n_subcarriers': 64
647
+ },
648
+ 'asu_campus1_v9': {
649
+ 'n_rows': [8*int(321/20), 9*int(321/20)],
650
+ 'n_per_row': 411,
651
+ 'n_ant_bs': 16,
652
+ 'n_subcarriers': 128
653
+ },
654
+ 'asu_campus1_v10': {
655
+ 'n_rows': [9*int(321/20), 10*int(321/20)],
656
+ 'n_per_row': 411,
657
+ 'n_ant_bs': 16,
658
+ 'n_subcarriers': 256
659
+ },
660
+ 'asu_campus1_v11': {
661
+ 'n_rows': [10*int(321/20), 11*int(321/20)],
662
+ 'n_per_row': 411,
663
+ 'n_ant_bs': 16,
664
+ 'n_subcarriers': 512
665
+ },
666
+ 'asu_campus1_v12': {
667
+ 'n_rows': [11*int(321/20), 12*int(321/20)],
668
+ 'n_per_row': 411,
669
+ 'n_ant_bs': 32,
670
+ 'n_subcarriers': 32
671
+ },
672
+ 'asu_campus1_v13': {
673
+ 'n_rows': [12*int(321/20), 13*int(321/20)],
674
+ 'n_per_row': 411,
675
+ 'n_ant_bs': 32,
676
+ 'n_subcarriers': 64
677
+ },
678
+ 'asu_campus1_v14': {
679
+ 'n_rows': [13*int(321/20), 14*int(321/20)],
680
+ 'n_per_row': 411,
681
+ 'n_ant_bs': 32,
682
+ 'n_subcarriers': 128
683
+ },
684
+ 'asu_campus1_v15': {
685
+ 'n_rows': [14*int(321/20), 15*int(321/20)],
686
+ 'n_per_row': 411,
687
+ 'n_ant_bs': 32,
688
+ 'n_subcarriers': 256
689
+ },
690
+ 'asu_campus1_v16': {
691
+ 'n_rows': [15*int(321/20), 16*int(321/20)],
692
+ 'n_per_row': 411,
693
+ 'n_ant_bs': 64,
694
+ 'n_subcarriers': 32
695
+ },
696
+ 'asu_campus1_v17': {
697
+ 'n_rows': [16*int(321/20), 17*int(321/20)],
698
+ 'n_per_row': 411,
699
+ 'n_ant_bs': 64,
700
+ 'n_subcarriers': 64
701
+ },
702
+ 'asu_campus1_v18': {
703
+ 'n_rows': [17*int(321/20), 18*int(321/20)],
704
+ 'n_per_row': 411,
705
+ 'n_ant_bs': 64,
706
+ 'n_subcarriers': 128
707
+ },
708
+ 'asu_campus1_v19': {
709
+ 'n_rows': [18*int(321/20), 19*int(321/20)],
710
+ 'n_per_row': 411,
711
+ 'n_ant_bs': 128,
712
+ 'n_subcarriers': 32
713
+ },
714
+ 'asu_campus1_v20': {
715
+ 'n_rows': [19*int(321/20), 20*int(321/20)],
716
+ 'n_per_row': 411,
717
+ 'n_ant_bs': 128,
718
+ 'n_subcarriers': 64
719
+ },
720
+ 'Boston5G_3p5_v1': {
721
+ 'n_rows': [812, 812 + 1*int((1622-812)/20)],
722
+ 'n_per_row': 595,
723
+ 'n_ant_bs': 8,
724
+ 'n_subcarriers': 32
725
+ },
726
+ 'Boston5G_3p5_v2': {
727
+ 'n_rows': [812 + 1*int((1622-812)/20), 812 + 2*int((1622-812)/20)],
728
+ 'n_per_row': 595,
729
+ 'n_ant_bs': 8,
730
+ 'n_subcarriers': 64
731
+ },
732
+ 'Boston5G_3p5_v3': {
733
+ 'n_rows': [812 + 2*int((1622-812)/20), 812 + 3*int((1622-812)/20)],
734
+ 'n_per_row': 595,
735
+ 'n_ant_bs': 8,
736
+ 'n_subcarriers': 128
737
+ },
738
+ 'Boston5G_3p5_v4': {
739
+ 'n_rows': [812 + 3*int((1622-812)/20), 812 + 4*int((1622-812)/20)],
740
+ 'n_per_row': 595,
741
+ 'n_ant_bs': 8,
742
+ 'n_subcarriers': 256
743
+ },
744
+ 'Boston5G_3p5_v5': {
745
+ 'n_rows': [812 + 4*int((1622-812)/20), 812 + 5*int((1622-812)/20)],
746
+ 'n_per_row': 595,
747
+ 'n_ant_bs': 8,
748
+ 'n_subcarriers': 512
749
+ },
750
+ 'Boston5G_3p5_v6': {
751
+ 'n_rows': [812 + 5*int((1622-812)/20), 812 + 6*int((1622-812)/20)],
752
+ 'n_per_row': 595,
753
+ 'n_ant_bs': 8,
754
+ 'n_subcarriers': 1024
755
+ },
756
+ 'Boston5G_3p5_v7': {
757
+ 'n_rows': [812 + 6*int((1622-812)/20), 812 + 7*int((1622-812)/20)],
758
+ 'n_per_row': 595,
759
+ 'n_ant_bs': 16,
760
+ 'n_subcarriers': 32
761
+ },
762
+ 'Boston5G_3p5_v8': {
763
+ 'n_rows': [812 + 7*int((1622-812)/20), 812 + 8*int((1622-812)/20)],
764
+ 'n_per_row': 595,
765
+ 'n_ant_bs':16,
766
+ 'n_subcarriers': 64
767
+ },
768
+ 'Boston5G_3p5_v9': {
769
+ 'n_rows': [812 + 8*int((1622-812)/20), 812 + 9*int((1622-812)/20)],
770
+ 'n_per_row': 595,
771
+ 'n_ant_bs': 16,
772
+ 'n_subcarriers': 128
773
+ },
774
+ 'Boston5G_3p5_v10': {
775
+ 'n_rows': [812 + 9*int((1622-812)/20), 812 + 10*int((1622-812)/20)],
776
+ 'n_per_row': 595,
777
+ 'n_ant_bs': 16,
778
+ 'n_subcarriers': 256
779
+ },
780
+ 'Boston5G_3p5_v11': {
781
+ 'n_rows': [812 + 10*int((1622-812)/20), 812 + 11*int((1622-812)/20)],
782
+ 'n_per_row': 595,
783
+ 'n_ant_bs': 16,
784
+ 'n_subcarriers': 512
785
+ },
786
+ 'Boston5G_3p5_v12': {
787
+ 'n_rows': [812 + 11*int((1622-812)/20), 812 + 12*int((1622-812)/20)],
788
+ 'n_per_row': 595,
789
+ 'n_ant_bs': 32,
790
+ 'n_subcarriers': 32
791
+ },
792
+ 'Boston5G_3p5_v13': {
793
+ 'n_rows': [812 + 12*int((1622-812)/20), 812 + 13*int((1622-812)/20)],
794
+ 'n_per_row': 595,
795
+ 'n_ant_bs': 32,
796
+ 'n_subcarriers': 64
797
+ },
798
+ 'Boston5G_3p5_v14': {
799
+ 'n_rows': [812 + 13*int((1622-812)/20), 812 + 14*int((1622-812)/20)],
800
+ 'n_per_row': 595,
801
+ 'n_ant_bs': 32,
802
+ 'n_subcarriers': 128
803
+ },
804
+ 'Boston5G_3p5_v15': {
805
+ 'n_rows': [812 + 14*int((1622-812)/20), 812 + 15*int((1622-812)/20)],
806
+ 'n_per_row': 595,
807
+ 'n_ant_bs': 32,
808
+ 'n_subcarriers': 256
809
+ },
810
+ 'Boston5G_3p5_v16': {
811
+ 'n_rows': [812 + 15*int((1622-812)/20), 812 + 16*int((1622-812)/20)],
812
+ 'n_per_row': 595,
813
+ 'n_ant_bs': 64,
814
+ 'n_subcarriers': 32
815
+ },
816
+ 'Boston5G_3p5_v17': {
817
+ 'n_rows': [812 + 16*int((1622-812)/20), 812 + 17*int((1622-812)/20)],
818
+ 'n_per_row': 595,
819
+ 'n_ant_bs': 64,
820
+ 'n_subcarriers': 64
821
+ },
822
+ 'Boston5G_3p5_v18': {
823
+ 'n_rows': [812 + 17*int((1622-812)/20), 812 + 18*int((1622-812)/20)],
824
+ 'n_per_row': 595,
825
+ 'n_ant_bs': 64,
826
+ 'n_subcarriers': 128
827
+ },
828
+ 'Boston5G_3p5_v19': {
829
+ 'n_rows': [812 + 18*int((1622-812)/20), 812 + 19*int((1622-812)/20)],
830
+ 'n_per_row': 595,
831
+ 'n_ant_bs': 128,
832
+ 'n_subcarriers': 32
833
+ },
834
+ 'Boston5G_3p5_v20': {
835
+ 'n_rows': [812 + 19*int((1622-812)/20), 812 + 20*int((1622-812)/20)],
836
+ 'n_per_row': 595,
837
+ 'n_ant_bs': 128,
838
+ 'n_subcarriers': 64
839
+ },
840
+ 'O1_3p5_v1': {
841
+ 'n_rows': [0*int(3852/12), 1*int(3852/12)],
842
+ 'n_per_row': 181,
843
+ 'n_ant_bs': 8,
844
+ 'n_subcarriers': 32
845
+ },
846
+ 'O1_3p5_v2': {
847
+ 'n_rows': [1*int(3852/12), 2*int(3852/12)],
848
+ 'n_per_row': 181,
849
+ 'n_ant_bs': 8,
850
+ 'n_subcarriers': 64
851
+ },
852
+ 'O1_3p5_v3': {
853
+ 'n_rows': [2*int(3852/12), 3*int(3852/12)],
854
+ 'n_per_row': 181,
855
+ 'n_ant_bs': 8,
856
+ 'n_subcarriers': 128
857
+ },
858
+ 'O1_3p5_v4': {
859
+ 'n_rows': [3*int(3852/12), 4*int(3852/12)],
860
+ 'n_per_row': 181,
861
+ 'n_ant_bs': 8,
862
+ 'n_subcarriers': 256
863
+ },
864
+ 'O1_3p5_v5': {
865
+ 'n_rows': [4*int(3852/12), 5*int(3852/12)],
866
+ 'n_per_row': 181,
867
+ 'n_ant_bs': 8,
868
+ 'n_subcarriers': 512
869
+ },
870
+ 'O1_3p5_v6': {
871
+ 'n_rows': [5*int(3852/12), 6*int(3852/12)],
872
+ 'n_per_row': 181,
873
+ 'n_ant_bs': 8,
874
+ 'n_subcarriers': 1024
875
+ },
876
+ 'O1_3p5_v7': {
877
+ 'n_rows': [6*int(3852/12), 7*int(3852/12)],
878
+ 'n_per_row': 181,
879
+ 'n_ant_bs': 16,
880
+ 'n_subcarriers': 32
881
+ },
882
+ 'O1_3p5_v8': {
883
+ 'n_rows': [7*int(3852/12), 8*int(3852/12)],
884
+ 'n_per_row': 181,
885
+ 'n_ant_bs': 16,
886
+ 'n_subcarriers': 64
887
+ },
888
+ 'O1_3p5_v9': {
889
+ 'n_rows': [8*int(3852/12), 9*int(3852/12)],
890
+ 'n_per_row': 181,
891
+ 'n_ant_bs': 16,
892
+ 'n_subcarriers': 128
893
+ },
894
+ 'O1_3p5_v10': {
895
+ 'n_rows': [9*int(3852/12), 10*int(3852/12)],
896
+ 'n_per_row': 181,
897
+ 'n_ant_bs': 16,
898
+ 'n_subcarriers': 256
899
+ },
900
+ 'O1_3p5_v11': {
901
+ 'n_rows': [10*int(3852/12), 11*int(3852/12)],
902
+ 'n_per_row': 181,
903
+ 'n_ant_bs': 16,
904
+ 'n_subcarriers': 512
905
+ },
906
+ 'O1_3p5_v12': {
907
+ 'n_rows': [11*int(3852/12), 12*int(3852/12)],
908
+ 'n_per_row': 181,
909
+ 'n_ant_bs': 32,
910
+ 'n_subcarriers': 32
911
+ },
912
+ 'O1_3p5_v13': {
913
+ 'n_rows': [12*int(3852/12)+0*int(1351/10), 12*int(3852/12)+1*int(1351/10)],
914
+ 'n_per_row': 361,
915
+ 'n_ant_bs': 32,
916
+ 'n_subcarriers': 64
917
+ },
918
+ 'O1_3p5_v14': {
919
+ 'n_rows': [12*int(3852/12)+1*int(1351/10), 12*int(3852/12)+2*int(1351/10)],
920
+ 'n_per_row': 181,
921
+ 'n_ant_bs': 32,
922
+ 'n_subcarriers': 128
923
+ },
924
+ 'O1_3p5_v15': {
925
+ 'n_rows': [12*int(3852/12)+2*int(1351/10), 12*int(3852/12)+3*int(1351/10)],
926
+ 'n_per_row': 181,
927
+ 'n_ant_bs': 32,
928
+ 'n_subcarriers': 256
929
+ },
930
+ 'O1_3p5_v16': {
931
+ 'n_rows': [12*int(3852/12)+3*int(1351/10), 12*int(3852/12)+4*int(1351/10)],
932
+ 'n_per_row': 181,
933
+ 'n_ant_bs': 64,
934
+ 'n_subcarriers': 32
935
+ },
936
+ 'O1_3p5_v17': {
937
+ 'n_rows': [12*int(3852/12)+4*int(1351/10), 12*int(3852/12)+5*int(1351/10)],
938
+ 'n_per_row': 181,
939
+ 'n_ant_bs': 64,
940
+ 'n_subcarriers': 64
941
+ },
942
+ 'O1_3p5_v18': {
943
+ 'n_rows': [12*int(3852/12)+5*int(1351/10), 12*int(3852/12)+6*int(1351/10)],
944
+ 'n_per_row': 181,
945
+ 'n_ant_bs': 64,
946
+ 'n_subcarriers': 128
947
+ },
948
+ 'O1_3p5_v19': {
949
+ 'n_rows': [12*int(3852/12)+6*int(1351/10), 12*int(3852/12)+7*int(1351/10)],
950
+ 'n_per_row': 181,
951
+ 'n_ant_bs': 128,
952
+ 'n_subcarriers': 32
953
+ },
954
+ 'O1_3p5_v20': {
955
+ 'n_rows': [12*int(3852/12)+7*int(1351/10), 12*int(3852/12)+8*int(1351/10)],
956
+ 'n_per_row': 181,
957
+ 'n_ant_bs': 128,
958
+ 'n_subcarriers': 64
959
+ },
960
+ 'city_0_newyork_v16x64': {
961
+ 'n_rows': 109,
962
+ 'n_per_row': 291,
963
+ 'n_ant_bs': 16,
964
+ 'n_subcarriers': 64
965
+ },
966
+ 'city_1_losangeles_v16x64': {
967
+ 'n_rows': 142,
968
+ 'n_per_row': 201,
969
+ 'n_ant_bs': 16,
970
+ 'n_subcarriers': 64
971
+ },
972
+ 'city_2_chicago_v16x64': {
973
+ 'n_rows': 139,
974
+ 'n_per_row': 200,
975
+ 'n_ant_bs': 16,
976
+ 'n_subcarriers': 64
977
+ },
978
+ 'city_3_houston_v16x64': {
979
+ 'n_rows': 154,
980
+ 'n_per_row': 202,
981
+ 'n_ant_bs': 16,
982
+ 'n_subcarriers': 64
983
+ },
984
+ 'city_4_phoenix_v16x64': {
985
+ 'n_rows': 198,
986
+ 'n_per_row': 214,
987
+ 'n_ant_bs': 16,
988
+ 'n_subcarriers': 64
989
+ },
990
+ 'city_5_philadelphia_v16x64': {
991
+ 'n_rows': 239,
992
+ 'n_per_row': 164,
993
+ 'n_ant_bs': 16,
994
+ 'n_subcarriers': 64
995
+ },
996
+ 'city_6_miami_v16x64': {
997
+ 'n_rows': 199,
998
+ 'n_per_row': 216,
999
+ 'n_ant_bs': 16,
1000
+ 'n_subcarriers': 64
1001
+ },
1002
+ 'city_7_sandiego_v16x64': {
1003
+ 'n_rows': 207,
1004
+ 'n_per_row': 176,
1005
+ 'n_ant_bs': 16,
1006
+ 'n_subcarriers': 64
1007
+ },
1008
+ 'city_8_dallas_v16x64': {
1009
+ 'n_rows': 207,
1010
+ 'n_per_row': 190,
1011
+ 'n_ant_bs': 16,
1012
+ 'n_subcarriers': 64
1013
+ },
1014
+ 'city_9_sanfrancisco_v16x64': {
1015
+ 'n_rows': 196,
1016
+ 'n_per_row': 206,
1017
+ 'n_ant_bs': 16,
1018
+ 'n_subcarriers': 64
1019
+ }}
1020
+ return row_column_users
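
As a concrete check of the patching arithmetic (my own illustration, not part of the upload): each 4x4 patch carries 4*4*2 = 32 real values after real/imaginary interleaving, and the token sequence length is the patch count plus the prepended [CLS] token. For example:

import numpy as np

def n_tokens(n_ant_bs, n_subcarriers, patch_rows=4, patch_cols=4):
    # patch grid computed exactly as in patch_maker(); +1 for the prepended [CLS] token
    n_patches = int(np.ceil(n_ant_bs / patch_rows)) * int(np.ceil(n_subcarriers / patch_cols))
    return n_patches + 1

print(n_tokens(8, 32))    # city_0_newyork (8 antennas, 32 subcarriers)   -> 17 tokens
print(n_tokens(64, 128))  # city_17_seattle (64 antennas, 128 subcarriers) -> 513 tokens = MAX_LEN
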
lwm_model.py ADDED
@@ -0,0 +1,154 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 19:23:54 2024
4
+
5
+ This script defines the LWM model architecture.
6
+
7
+ @author: Sadjad Alikhani
8
+ """
9
+ #%%
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ #%%
15
+ class LayerNormalization(nn.Module):
16
+ def __init__(self, d_model: int, eps: float = 1e-6) -> None:
17
+ super().__init__()
18
+ self.eps = eps
19
+ self.alpha = nn.Parameter(torch.ones(d_model))
20
+ self.bias = nn.Parameter(torch.zeros(d_model))
21
+
22
+ def forward(self, x):
23
+ mean = x.mean(dim=-1, keepdim=True)
24
+ std = x.std(dim=-1, keepdim=True)
25
+ return self.alpha * (x - mean) / (std + self.eps) + self.bias
26
+
27
+
28
+ class Embedding(nn.Module):
29
+ def __init__(self, element_length, d_model, max_len=513):
30
+ super().__init__()
31
+ self.element_length = element_length
32
+ self.d_model = d_model
33
+ self.proj = nn.Linear(element_length, d_model)
34
+ self.pos_embed = nn.Embedding(max_len, d_model)
35
+ self.norm = LayerNormalization(d_model)
36
+
37
+ def forward(self, x):
38
+ seq_len = x.size(1)
39
+ pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
40
+ pos_encodings = self.pos_embed(pos)
41
+ tok_emb = self.proj(x.float())
42
+ embedding = tok_emb + pos_encodings
43
+ return self.norm(embedding)
44
+
45
+
46
+ class ScaledDotProductAttention(nn.Module):
47
+ def __init__(self, d_k):
48
+ super().__init__()
49
+ self.d_k = d_k
50
+
51
+ def forward(self, Q, K, V):
52
+ scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
53
+ attn = F.softmax(scores, dim=-1)
54
+ context = torch.matmul(attn, V)
55
+ return context, attn
56
+
57
+
58
+ class MultiHeadAttention(nn.Module):
59
+ def __init__(self, d_model, n_heads, dropout):
60
+ super().__init__()
61
+ self.d_k = d_model // n_heads
62
+ self.d_v = d_model // n_heads
63
+ self.n_heads = n_heads
64
+ self.W_Q = nn.Linear(d_model, self.d_k * n_heads)
65
+ self.W_K = nn.Linear(d_model, self.d_k * n_heads)
66
+ self.W_V = nn.Linear(d_model, self.d_v * n_heads)
67
+ self.linear = nn.Linear(n_heads * self.d_v, d_model)
68
+ self.dropout = nn.Dropout(dropout)
69
+ self.scaled_dot_attn = ScaledDotProductAttention(self.d_k)
70
+
71
+ def forward(self, Q, K, V):
72
+ residual, batch_size = Q, Q.size(0)
73
+ q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
74
+ k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
75
+ v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
76
+
77
+ context, attn = self.scaled_dot_attn(q_s, k_s, v_s)
78
+ output = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
79
+ output = self.linear(output)
80
+ return residual + self.dropout(output), attn
81
+
82
+
83
+ class PoswiseFeedForwardNet(nn.Module):
84
+ def __init__(self, d_model, d_ff, dropout):
85
+ super().__init__()
86
+ self.fc1 = nn.Linear(d_model, d_ff)
87
+ self.fc2 = nn.Linear(d_ff, d_model)
88
+ self.dropout = nn.Dropout(dropout)
89
+
90
+ def forward(self, x):
91
+ return self.fc2(self.dropout(F.relu(self.fc1(x))))
92
+
93
+
94
+ class EncoderLayer(nn.Module):
95
+ def __init__(self, d_model, n_heads, d_ff, dropout):
96
+ super().__init__()
97
+ self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
98
+ self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff, dropout)
99
+ self.norm1 = LayerNormalization(d_model)
100
+ self.norm2 = LayerNormalization(d_model)
101
+
102
+ def forward(self, enc_inputs):
103
+ # Self-Attention with Add & Norm
104
+ attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
105
+ attn_outputs = self.norm1(enc_inputs + attn_outputs) # Add & Norm
106
+
107
+ # Feed-Forward with Add & Norm
108
+ ff_outputs = self.pos_ffn(attn_outputs)
109
+ enc_outputs = self.norm2(attn_outputs + ff_outputs) # Add & Norm
110
+
111
+ return enc_outputs, attn
112
+
113
+
114
+ class lwm(nn.Module):
115
+ def __init__(self, element_length=32, d_model=128, n_layers=12, max_len=513, n_heads=8, dropout=0.1):
116
+ super().__init__()
117
+ self.embedding = Embedding(element_length, d_model, max_len)
118
+ self.layers = nn.ModuleList(
119
+ [EncoderLayer(d_model, n_heads, d_model*4, dropout) for _ in range(n_layers)]
120
+ )
121
+ self.linear = nn.Linear(d_model, d_model)
122
+ self.norm = LayerNormalization(d_model)
123
+
124
+ embed_weight = self.embedding.proj.weight
125
+ _, n_dim = embed_weight.size()
126
+ self.decoder = nn.Linear(d_model, n_dim, bias=False)
127
+ self.decoder_bias = nn.Parameter(torch.zeros(n_dim))
128
+
129
+ @classmethod
130
+ def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
131
+ model = cls().to(device)
132
+ model.load_state_dict(torch.load(ckpt_name, map_location=device))
133
+ print(f"Model loaded successfully from {ckpt_name}")
134
+ return model
135
+
136
+ def forward(self, input_ids, masked_pos=None):
137
+ # Step 1: Embedding
138
+ output = self.embedding(input_ids)
139
+ attention_maps = []
140
+
141
+ # Step 2: Pass through Encoder Layers
142
+ for layer in self.layers:
143
+ output, attn = layer(output)
144
+ attention_maps.append(attn)
145
+
146
+ # If masked_pos is provided, perform masked token prediction
147
+ if masked_pos is not None:
148
+ masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
149
+ h_masked = torch.gather(output, 1, masked_pos)
150
+ h_masked = self.norm(F.relu(self.linear(h_masked)))
151
+ logits_lm = self.decoder(h_masked) + self.decoder_bias
152
+ return logits_lm, output, attention_maps
153
+ else:
154
+ return output, attention_maps
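
A small smoke test of the forward pass under the default hyperparameters above; the batch size, sequence length, and masked positions below are made-up values for illustration only.

import torch
from lwm_model import lwm

model = lwm()                                  # element_length=32, d_model=128, 12 layers, 8 heads
x = torch.randn(2, 17, 32)                     # [CLS] + 16 patches of 32 interleaved real/imag values
masked_pos = torch.tensor([[1, 5], [2, 7]])    # patch positions to reconstruct
logits_lm, enc_out, attn_maps = model(x, masked_pos)
print(logits_lm.shape, enc_out.shape, len(attn_maps))   # (2, 2, 32), (2, 17, 128), 12

enc_out, attn_maps = model(x)                  # without masked_pos: embeddings only, as used in inference.py
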
main.py ADDED
@@ -0,0 +1,120 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sat Dec 21 13:24:21 2024
4
+
5
+ This script pre-trains the LWM model.
6
+
7
+ @author: salikha4
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils.data import random_split
12
+ from input_preprocess import tokenizer, scenarios_list
13
+ from utils import create_dataloader, count_parameters
14
+ import numpy as np
15
+ import lwm_model
16
+ from torch.optim.lr_scheduler import CosineAnnealingLR
17
+ from torch.optim.lr_scheduler import LambdaLR
18
+ from torch.optim import AdamW
19
+ from train import train_lwm
20
+ import warnings
21
+ warnings.filterwarnings("ignore", category=UserWarning)
22
+ #%% SETTINGS
23
+ EPOCHS = 50
24
+ BATCH_SIZE = 128
25
+ VAL_BATCH_SIZE = 64
26
+ WARMUP_EPOCHS = 5
27
+ BASE_LR = 5e-4
28
+ N_ROWS = 4
29
+ N_COLUMNS = 4
30
+ ELEMENT_LENGTH = N_ROWS*N_COLUMNS*2
31
+ D_MODEL = 128
32
+ MAX_LEN = 513
33
+ N_LAYERS = 12
34
+ WEIGHT_DECAY = 0.05
35
+ BETA1 = 0.9
36
+ BETA2 = 0.999
37
+ MASK_PERCENT = 0.40
38
+ N_HEADS = 8
39
+ DROPOUT = 0.1
40
+ #%% GENERATE DATASET
41
+ bs_idxs = [1, 2, 3]
42
+ selected_scenario_names = scenarios_list()[:80]
43
+ preprocessed_data, _, _ = tokenizer(  # keep only the masked, length-grouped samples for pre-training
44
+ selected_scenario_names,
45
+ bs_idxs=bs_idxs,
46
+ MAX_LEN=MAX_LEN,
47
+ masking_percent=MASK_PERCENT,
48
+ mask=True, seed=42
49
+ )
50
+ #%% SPLIT DATASET
51
+ SEED = 42
52
+ torch.manual_seed(SEED)
53
+ np.random.seed(SEED)
54
+ train_ratio = 0.8
55
+ val_ratio = 0.2
56
+ train_data = {}
57
+ val_data = {}
58
+ test_data = {}
59
+ for key, samples in preprocessed_data.items():
60
+ print(f"key: {key}")
61
+ total_samples = len(samples)
62
+ train_size = int(train_ratio * total_samples)
63
+ val_size = int(val_ratio * total_samples)
64
+ test_size = total_samples - val_size - train_size
65
+
66
+ train_data[key], val_data[key], test_data[key] = random_split(
67
+ samples, [train_size, val_size, test_size]
68
+ )
69
+ train_loaders = create_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
70
+ val_loaders = create_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
71
+ #%% INITIALIZE MODEL
72
+ load_model = True
73
+ gpu_ids = [0]
74
+ device = torch.device("cuda:0")
75
+ model = lwm_model.lwm().to(device)
76
+
77
+ if load_model:
78
+ model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
79
+ state_dict = torch.load(f"models/{model_name}", map_location=device)
80
+ new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
81
+ model.load_state_dict(new_state_dict)
82
+
83
+ model = nn.DataParallel(model, gpu_ids)
84
+ print(f"Model loaded successfully on GPU {device.index}")
85
+
86
+ n_parameters = count_parameters(model)
87
+ print(f"Number of trainable parameters: {n_parameters:,}")
88
+ #%% OPTIMIZER AND SCHEDULER
89
+ BASE_LR = 5e-5
90
+ MIN_LR = 1e-8
91
+ TOTAL_STEPS = sum(len(loader) for loader in train_loaders.values()) * EPOCHS
92
+ WARMUP_STEPS = sum(len(loader) for loader in train_loaders.values()) * WARMUP_EPOCHS
93
+
94
+ optimizer = AdamW(
95
+ model.parameters(),
96
+ lr=BASE_LR,
97
+ betas=(BETA1, BETA2),
98
+ weight_decay=WEIGHT_DECAY
99
+ )
100
+ def lr_lambda(current_step):
101
+ if current_step < WARMUP_STEPS:
102
+ # Linear warmup
103
+ return current_step / WARMUP_STEPS
104
+ else:
105
+ # Scaled cosine decay
106
+ scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
107
+ cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
108
+ return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR
109
+
110
+ scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
111
+ #%% PRE-TRAIN THE MODEL
112
+ pretrained_model = train_lwm(
113
+ model,
114
+ train_loaders,
115
+ val_loaders,
116
+ optimizer,
117
+ scheduler,
118
+ EPOCHS,
119
+ device=device
120
+ )
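
For reference, the LambdaLR schedule defined above amounts to linear warmup followed by scaled cosine decay. Multiplying the returned factor by BASE_LR gives, at step t:

lr(t) = BASE_LR * t / WARMUP_STEPS                                                                        for t < WARMUP_STEPS
lr(t) = MIN_LR + 0.5 * (BASE_LR - MIN_LR) * (1 + cos(pi * (t - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)))   otherwise

so the learning rate ramps up to BASE_LR = 5e-5 over the first WARMUP_EPOCHS = 5 epochs and then decays to MIN_LR = 1e-8 at the final step (a final multiplier of MIN_LR / BASE_LR = 2e-4).
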
models/model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:485611f1a0f819f9c673827b8e613887b39672e97072bd7a412866b49d8dd40f
3
+ size 9960738
train.py ADDED
@@ -0,0 +1,446 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Dec 20 09:32:12 2024
4
+
5
+ This script contains the LWM pre-training and task-specific fine-tuning functions.
6
+
7
+ @author: Sadjad Alikhani
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+ from tqdm import tqdm
12
+ import matplotlib.pyplot as plt
13
+ import os
14
+ import csv
15
+ from utils import count_parameters
16
+ import time
17
+ #%% LOSS FUNCTION
18
+ def nmse_loss(y_pred, y_true):
19
+ y_pred_flat = y_pred.view(y_pred.size(0), -1)
20
+ y_true_flat = y_true.view(y_true.size(0), -1)
21
+ mse = torch.sum((y_true_flat - y_pred_flat)**2, dim=-1)
22
+ normalization = torch.sum(y_true_flat**2, dim=-1)
23
+ return mse / normalization
24
+ #%%
25
+ def train_lwm(model, train_loaders, val_loaders, optimizer, scheduler, epochs, device, save_dir="models", log_file="training_log.csv"):
26
+
27
+ if not os.path.exists(save_dir):
28
+ os.makedirs(save_dir)
29
+
30
+ # Initialize CSV log
31
+ if not os.path.exists(log_file):
32
+ with open(log_file, mode='w', newline='') as file:
33
+ writer = csv.writer(file)
34
+ writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])
35
+
36
+ train_nmse_losses = []
37
+ val_nmse_losses = []
38
+ best_val_nmse = float('inf')
39
+
40
+ for epoch in range(epochs):
41
+ model.train()
42
+ train_nmse = 0.0
43
+ train_samples = 0
44
+
45
+ # Training loop across all buckets
46
+ print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
47
+ for length, train_loader in train_loaders.items():
48
+ print(f"Processing sequences of length {length}")
49
+ with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
50
+ for batch in t:
51
+ # train_batches += 1
52
+ optimizer.zero_grad()
53
+
54
+ # Move data to device
55
+ input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
56
+
57
+ # Forward pass
58
+ logits_lm, _, _ = model(input_ids, masked_pos)
59
+
60
+ # Compute NMSE
61
+ loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
62
+ loss.backward()
63
+ optimizer.step()
64
+ scheduler.step()
65
+
66
+ train_nmse += loss.item()
67
+ train_samples += input_ids.shape[0]
68
+
69
+ # Update progress bar
70
+ t.set_postfix({"nmse": train_nmse/train_samples, "lr": scheduler.get_last_lr()[0]})
71
+
72
+ # Average NMSE across training batches
73
+ train_nmse /= max(train_samples, 1)
74
+ train_nmse_losses.append(train_nmse)
75
+
76
+ if epoch % 2 == 0:
77
+ # Validation loop across all buckets
78
+ model.eval()
79
+ val_nmse = 0.0
80
+ val_samples = 0
81
+ with torch.no_grad():
82
+ print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
83
+ for length, val_loader in val_loaders.items():
84
+ print(f"Processing sequences of length {length}")
85
+ with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
86
+ for batch in t:
87
+ # val_batches += 1
88
+
89
+ # Move data to device
90
+ input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]
91
+
92
+ # Forward pass
93
+ logits_lm, _, _ = model(input_ids, masked_pos)
94
+
95
+ # Compute NMSE
96
+ loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
97
+ val_nmse += loss.item()
98
+ val_samples += input_ids.shape[0]
99
+
100
+ # Update progress bar
101
+ t.set_postfix({"nmse": val_nmse/val_samples})
102
+
103
+ # Average NMSE across validation batches
104
+ val_nmse /= max(val_samples, 1)
105
+ val_nmse_losses.append(val_nmse)
106
+
107
+ # Save model if validation NMSE improves
108
+ is_best_model = False
109
+ if val_nmse < best_val_nmse:
110
+ best_val_nmse = val_nmse
111
+ model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
112
+ torch.save(model.state_dict(), model_path)
113
+ print(f"Model saved: {model_path}")
114
+ is_best_model = True
115
+
116
+ # Log the results
117
+ print(f" Train NMSE: {train_nmse:.4f}")
118
+ print(f" Validation NMSE: {val_nmse:.4f}")
119
+ print(f" Learning Rate: {scheduler.get_last_lr()[0]:.6e}")
120
+
121
+ # Append to CSV log
122
+ with open(log_file, mode='a', newline='') as file:
123
+ writer = csv.writer(file)
124
+ writer.writerow([epoch + 1, train_nmse, val_nmse, scheduler.get_last_lr()[0], is_best_model])
125
+
126
+ # Plot losses after each epoch
127
+ plt.figure(figsize=(10, 6))
128
+ plt.plot(range(1, len(train_nmse_losses) + 1), train_nmse_losses, label="Train NMSE")
129
+ plt.plot(range(1, len(val_nmse_losses) + 1), val_nmse_losses, label="Validation NMSE")
130
+ plt.xlabel("Epochs")
131
+ plt.ylabel("NMSE")
132
+ plt.title("Training and Validation NMSE Loss")
133
+ plt.legend()
134
+ plt.grid(True)
135
+ plt.show()
136
+
137
+ print("Training and validation complete.")
138
+ return model
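A minimal sketch of how train_lwm could be driven, assuming the bucketed loaders come from create_dataloader in utils.py; the AdamW and cosine-annealing choices below are illustrative placeholders, not the settings used for the released checkpoint:

    # train_loaders / val_loaders: dict mapping sequence length -> DataLoader of
    # (input_ids, masked_tokens, masked_pos) batches, e.g. built by create_dataloader.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    steps_per_epoch = sum(len(dl) for dl in train_loaders.values())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50 * steps_per_epoch)
    model = train_lwm(model, train_loaders, val_loaders, optimizer, scheduler,
                      epochs=50, device="cuda:0")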
139
+ #%% FINE-TUNE
140
+ from torch.cuda.amp import GradScaler, autocast
141
+
142
+ # Define the ClassificationHead
143
+ class ClassificationHead(nn.Module):
144
+ def __init__(self, input_dim, num_classes):
145
+ super().__init__()
146
+ self.fc = nn.Linear(input_dim, num_classes)
147
+
148
+ def forward(self, x):
149
+ return self.fc(x)
150
+
151
+
152
+ # Define the RegressionHead
153
+ class RegressionHead(nn.Module):
154
+ def __init__(self, input_dim):
155
+ super().__init__()
156
+ self.fc = nn.Linear(input_dim, 1)
157
+
158
+ def forward(self, x):
159
+ return self.fc(x)
160
+
161
+ class CustomClassificationHead(nn.Module):
162
+ def __init__(self, input_dim, num_classes):
163
+
164
+ super().__init__()
165
+ self.classifier = nn.Sequential(
166
+ nn.Linear(input_dim, 512),
167
+ nn.BatchNorm1d(512),
168
+ nn.ReLU(),
169
+ nn.Dropout(0.1),
170
+ nn.Linear(512, 256),
171
+ nn.BatchNorm1d(256),
172
+ nn.ReLU(),
173
+ nn.Dropout(0.1),
174
+ nn.Linear(256, 128),
175
+ nn.BatchNorm1d(128),
176
+ nn.ReLU(),
177
+ # nn.Dropout(0.1),
178
+ nn.Linear(128, num_classes)
179
+ )
180
+
181
+ def forward(self, x):
182
+ return self.classifier(x)
183
+
184
+ class CustomRegressionHead(nn.Module):
185
+ def __init__(self, input_dim, output_dim):
186
+
187
+ super().__init__()
188
+ self.regressor = nn.Sequential(
189
+ nn.Linear(input_dim, 512),
190
+ nn.BatchNorm1d(512),
191
+ nn.ReLU(),
192
+ nn.Dropout(0.1),
193
+ nn.Linear(512, 256),
194
+ nn.BatchNorm1d(256),
195
+ nn.ReLU(),
196
+ nn.Dropout(0.1),
197
+ nn.Linear(256, output_dim)
198
+ )
199
+
200
+ def forward(self, x):
201
+ return self.regressor(x)
202
+
203
+
204
+ def custom_heads(input_dim, num_classes=None, output_dim=None, task_type="classification"):
205
+ """
206
+ Creates a custom head for classification or regression tasks.
207
+ Users should modify the class implementations for further customization.
208
+
209
+ Args:
210
+ input_dim (int): Input dimension of the head.
211
+ num_classes (int): Number of classes for classification tasks. Ignored for regression.
+ output_dim (int): Output dimension for regression tasks. Ignored for classification.
212
+ task_type (str): "classification" or "regression".
213
+
214
+ Returns:
215
+ nn.Module: Custom head for the specified task.
216
+ """
217
+ if task_type == "classification":
218
+ if num_classes is None:
219
+ raise ValueError("num_classes must be specified for classification tasks.")
220
+ return CustomClassificationHead(input_dim=input_dim, num_classes=num_classes)
221
+ elif task_type == "regression":
222
+ return CustomRegressionHead(input_dim=input_dim, output_dim=output_dim)
223
+ else:
224
+ raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
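For example, the factory could be used as follows; the 128-dimensional input matches the CLS-embedding size assumed later in finetune, while the regression output size below is only a placeholder:

    # 16-way beam classifier on a 128-dim CLS embedding
    cls_head = custom_heads(input_dim=128, num_classes=16, task_type="classification")
    # regression head mapping a 128-dim embedding to a flattened target (size assumed)
    reg_head = custom_heads(input_dim=128, output_dim=1024, task_type="regression")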
225
+ #%%
226
+ # Fine-tuning wrapper for the base model
227
+ class FineTuningWrapper(nn.Module):
228
+ def __init__(self, model, task_head, fine_tune_layers="full"):
229
+ super().__init__()
230
+ self.model = model
231
+ self.task_head = task_head
232
+
233
+ # Freeze all layers initially
234
+ for param in self.model.parameters():
235
+ param.requires_grad = False
236
+
237
+ # Handle fine-tuning layers
238
+ if fine_tune_layers is not None:
239
+ if fine_tune_layers == "full":
240
+ # Unfreeze all layers if "full" is specified
241
+ for param in self.model.parameters():
242
+ param.requires_grad = True
243
+ else:
244
+ # Get a list of all available layer names in the model
245
+ available_layers = [name for name, _ in self.model.named_parameters()]
246
+
247
+ # Validate that specified layers exist in the model
248
+ for layer in fine_tune_layers:
249
+ if not any(layer in lname for lname in available_layers):
250
+ raise ValueError(
251
+ f"Layer '{layer}' not found in the model. "
252
+ f"Available layers: {available_layers}"
253
+ )
254
+
255
+ # Unfreeze only the specified layers
256
+ for name, param in self.model.named_parameters():
257
+ if any(layer in name for layer in fine_tune_layers):
258
+ param.requires_grad = True
259
+
260
+ def forward(self, x, input_type="cls_emb"):
261
+ if input_type == "raw":
262
+ task_input = x.view(x.size(0), -1)
263
+ else:
264
+ embeddings, attn_maps = self.model(x) # Get embeddings from the base model
265
+ if input_type == "cls_emb":
266
+ task_input = embeddings[:, 0, :] # CLS token
267
+ elif input_type == "channel_emb":  # match the input-type name used in finetune and downstream.py
268
+ chs_emb = embeddings[:, 1:, :]
269
+ task_input = chs_emb.view(chs_emb.size(0), -1)  # flatten the channel-token embeddings (mean pooling over tokens is an alternative)
270
+
271
+ return self.task_head(task_input), 0 if input_type=="raw" else attn_maps
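As an illustration of the wrapper, the snippet below freezes everything except the last two encoder blocks (layer names follow the "layers.N" convention used in downstream.py) and attaches a simple linear head; lwm_backbone and tokens are assumed variables:

    head = ClassificationHead(input_dim=128, num_classes=16)
    wrapper = FineTuningWrapper(lwm_backbone, head,
                                fine_tune_layers=["layers.10", "layers.11"])
    logits, attn_maps = wrapper(tokens, input_type="cls_emb")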
272
+ #%%
273
+ # Fine-tuning function
274
+ from sklearn.metrics import f1_score
275
+ def finetune(
276
+ base_model,
277
+ train_loader,
278
+ val_loader=None,
279
+ task_type="classification",
280
+ input_type="cls_emb",
281
+ num_classes=None,
282
+ output_dim=None,
283
+ use_custom_head=False,
284
+ fine_tune_layers=None,
285
+ optimizer_config=None,
286
+ criterion=None,
287
+ epochs=10,
288
+ device="cuda",
289
+ task="Beam Prediction"
290
+ ):
291
+ """
292
+ Configures and fine-tunes the base model with user-defined settings, saving results and models.
293
+ """
294
+ # Create results folder
295
+ time_now = f"{time.time():.0f}"
296
+ results_folder = f"results/{task}/{time_now}"
297
+ os.makedirs(results_folder, exist_ok=True)
298
+ log_file = os.path.join(results_folder, "training_log.csv")
299
+
300
+ # Initialize the CSV log
301
+ with open(log_file, mode='w', newline='') as file:
302
+ writer = csv.writer(file)
303
+ writer.writerow(["Task", "Input", "Epoch", "Train Loss", "Validation Loss", "F1-Score (Classification)", "Learning Rate", "Time"])
304
+
305
+ for batch in val_loader:
306
+ input_data, targets = batch[0].to(device), batch[1].to(device)
307
+ break
308
+
309
+ if input_type == "cls_emb":
310
+ n_patches = 1
311
+ patch_size = 128
312
+ elif input_type == "channel_emb":
313
+ n_patches = input_data.shape[1]-1
314
+ patch_size = 128
315
+ elif input_type == "raw":
316
+ n_patches = input_data.shape[1]
317
+ patch_size = 32
318
+ # patch_size = 1
319
+
320
+ if use_custom_head:
321
+ custom_head = custom_heads(input_dim=n_patches*patch_size,
322
+ num_classes=num_classes,
323
+ output_dim=output_dim,
324
+ task_type=task_type)
325
+
326
+ # Handle DataParallel models
327
+ if isinstance(base_model, nn.DataParallel):
328
+ base_model = base_model.module
329
+
330
+ # Set up the task-specific head
331
+ if use_custom_head:
332
+ task_head = custom_head
333
+ elif task_type == "classification":
334
+ if num_classes is None:
335
+ raise ValueError("num_classes must be specified for classification tasks.")
336
+ task_head = ClassificationHead(input_dim=n_patches*patch_size, num_classes=num_classes) # input_dim=base_model.embedding.d_model
337
+ elif task_type == "regression":
338
+ task_head = RegressionHead(input_dim=n_patches*patch_size) # input_dim=base_model.embedding.d_model
339
+ else:
340
+ raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
341
+
342
+ # Wrap the model with the fine-tuning head
343
+ wrapper = FineTuningWrapper(base_model, task_head, fine_tune_layers=fine_tune_layers)
344
+ wrapper = wrapper.to(device)
345
+
346
+ print(f'Number of trainable parameters: {count_parameters(wrapper)}')
347
+
348
+ # Set default optimizer config if not provided
349
+ if optimizer_config is None:
350
+ optimizer_config = {"lr": 1e-4}
351
+ # Set up the optimizer
352
+ optimizer = torch.optim.Adam(wrapper.parameters(), **optimizer_config)
353
+ # Set up the scheduler for learning rate decay
354
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)  # Reduce LR by 5x every 10 epochs
355
+
356
+ # Set up the loss criterion
357
+ if criterion is None:
358
+ criterion = nn.CrossEntropyLoss() if task_type == "classification" else nn.MSELoss()
359
+
360
+ scaler = GradScaler()
361
+ train_losses, val_losses, f1_scores = [], [], []
362
+ best_val_loss = float("inf")
363
+ best_model_path = None
364
+
365
+ for epoch in range(epochs):
366
+ # Training loop
367
+ wrapper.train()
368
+ epoch_loss = 0.0
369
+
370
+ with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}") as progress_bar:
371
+ for batch in progress_bar:
372
+ input_data, targets = batch[0].to(device), batch[1].to(device)
373
+ optimizer.zero_grad()
374
+
375
+ with autocast():
376
+ outputs, attn_maps = wrapper(input_data, input_type=input_type)
377
+ loss = criterion(outputs, targets)
378
+
379
+ scaler.scale(loss).backward()
380
+ scaler.step(optimizer)
381
+ scaler.update()
382
+
383
+ epoch_loss += loss.item()
384
+ progress_bar.set_postfix({"Loss": loss.item()})
385
+
386
+ avg_train_loss = epoch_loss / len(train_loader)
387
+ train_losses.append(avg_train_loss)
388
+
389
+ # Validation loop
390
+ if val_loader:
391
+ wrapper.eval()
392
+ val_loss = 0.0
393
+ all_preds, all_targets = [], []
394
+
395
+ with torch.no_grad():
396
+ for batch in val_loader:
397
+ input_data, targets = batch[0].to(device), batch[1].to(device)
398
+ with autocast():
399
+ outputs, _ = wrapper(input_data, input_type=input_type)
400
+ loss = criterion(outputs, targets)
401
+
402
+ val_loss += loss.item()
403
+
404
+ if task_type == "classification":
405
+ preds = torch.argmax(outputs, dim=1).cpu().numpy()
406
+ all_preds.extend(preds)
407
+ all_targets.extend(targets.cpu().numpy())
408
+
409
+ avg_val_loss = val_loss / len(val_loader)
410
+ val_losses.append(avg_val_loss)
411
+
412
+ time_now = f"{time.time():.0f}"
413
+ # Save the best model
414
+ if avg_val_loss < best_val_loss:
415
+ best_val_loss = avg_val_loss
416
+ best_model_path = os.path.join(results_folder, f"{input_type}_epoch{epoch+1}_valLoss{avg_val_loss:.4f}_{time_now}.pth")
417
+ torch.save(wrapper.state_dict(), best_model_path)
418
+ print(f"Model saved at {best_model_path} with validation loss: {best_val_loss:.4f}")
419
+
420
+ # Compute F1-score for classification tasks
421
+ f1 = None
422
+ if task_type == "classification":
423
+ f1 = f1_score(all_targets, all_preds, average="macro")
424
+ print(f"Epoch {epoch + 1}, Validation F1-Score: {f1:.4f}")
425
+ f1_scores.append(f1)
426
+
427
+ scheduler.step()
428
+
429
+ # Log results
430
+ with open(log_file, mode='a', newline='') as file:
431
+ writer = csv.writer(file)
432
+ writer.writerow([task, input_type, epoch + 1, avg_train_loss, avg_val_loss, f1 if f1 is not None else "-", scheduler.get_last_lr()[0], f"{time_now}"])
433
+
434
+ # Plot training and validation losses
435
+ plt.figure(figsize=(10, 6))
436
+ plt.plot(range(1, epochs + 1), train_losses, label="Training Loss")
437
+ plt.plot(range(1, epochs + 1), val_losses, label="Validation Loss", linestyle="--")
438
+ plt.xlabel("Epochs")
439
+ plt.ylabel("Loss")
440
+ plt.title("Training and Validation Loss")
441
+ plt.legend()
442
+ plt.grid(True)
443
+ # plt.savefig(os.path.join(results_folder, "loss_curve.png"))
444
+ plt.show()
445
+
446
+ return wrapper, best_model_path, train_losses, val_losses, f1_scores if task_type == "classification" else 0, attn_maps
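A representative call, mirroring the downstream loop; train_loader and val_loader would come from prepare_loaders in utils.py, and the hyperparameters shown are placeholders:

    wrapper, best_ckpt, train_losses, val_losses, f1_scores, attn_maps = finetune(
        base_model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        task_type="classification",
        input_type="cls_emb",
        num_classes=16,
        fine_tune_layers=None,   # None = train the head only; "full" unfreezes LWM
        epochs=10,
        device="cuda:0",
        task="Beam Prediction")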
utils.py ADDED
@@ -0,0 +1,247 @@
1
+ from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
2
+ import torch
3
+ import numpy as np
4
+ #%%
5
+ def create_dataloader(grouped_data, batch_size, shuffle):
6
+
7
+ dataloaders = {}
8
+
9
+ for seq_length, group in grouped_data.items():
10
+
11
+ print(f"dataloader in progress ...\nkey: {seq_length}")
12
+
13
+ ## Uncomment the following line if you run out of memory during pre-training
14
+ # batch_size = batch_size // 8 if seq_length >= 5 else batch_size
15
+
16
+ # Unpack samples for the current group
17
+ input_ids, masked_tokens, masked_pos = zip(*group)
18
+
19
+ # Convert to tensors
20
+ input_ids_tensor = torch.tensor(input_ids, dtype=torch.float32)
21
+ masked_tokens_tensor = torch.tensor(masked_tokens, dtype=torch.float32)
22
+ masked_pos_tensor = torch.tensor(masked_pos, dtype=torch.long)
23
+
24
+ # Create TensorDataset and DataLoader
25
+ dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
26
+ dataloaders[seq_length] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)
27
+
28
+ return dataloaders
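As a usage sketch (variable names assumed), grouped_data maps each sequence length to a list of (input_ids, masked_tokens, masked_pos) samples, and the returned dict is iterated bucket by bucket during pre-training:

    loaders = create_dataloader(grouped_data, batch_size=128, shuffle=True)
    for seq_len, loader in loaders.items():
        for input_ids, masked_tokens, masked_pos in loader:
            pass  # forward/backward pass per batch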
29
+ #%%
30
+ def count_parameters(model):
31
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
32
+ #%%
33
+ import matplotlib.pyplot as plt
34
+ from sklearn.decomposition import PCA
35
+ from sklearn.manifold import TSNE
36
+ import umap
37
+
38
+ def visualize_embeddings(embeddings, labels, method="pca", label=None):
39
+ """
40
+ Visualize embeddings using PCA, UMAP, or t-SNE with color-coded labels.
41
+
42
+ Args:
43
+ embeddings (torch.Tensor or np.ndarray): Embeddings to visualize, shape (n_samples, n_features).
44
+ labels (torch.Tensor or np.ndarray): Class labels corresponding to embeddings, shape (n_samples,).
45
+ method (str): Dimensionality reduction method ('pca', 'umap', or 'tsne').
46
+ label (str): Text used in the plot title, combined with the method name.
47
+ """
48
+ # Convert to numpy if input is a torch.Tensor
49
+ if isinstance(embeddings, torch.Tensor):
50
+ embeddings = embeddings.cpu().numpy()
51
+ if isinstance(labels, torch.Tensor):
52
+ labels = labels.cpu().numpy()
53
+
54
+ # Apply the selected dimensionality reduction method
55
+ if method.lower() == "pca":
56
+ reducer = PCA(n_components=2)
57
+ elif method.lower() == "umap":
58
+ reducer = umap.UMAP(n_components=2, n_neighbors=16, random_state=42)
59
+ elif method.lower() == "tsne":
60
+ reducer = TSNE(n_components=2, random_state=42, init="random")
61
+ else:
62
+ raise ValueError("Invalid method. Choose from 'pca', 'umap', or 'tsne'.")
63
+
64
+ reduced_embeddings = reducer.fit_transform(embeddings)
65
+
66
+ # Create a scatter plot with color-coding based on labels
67
+ plt.figure(figsize=(10, 8))
68
+ num_classes = len(np.unique(labels))
69
+ colors = plt.cm.get_cmap("tab10", num_classes)
70
+
71
+ for class_idx in range(num_classes):
72
+ class_points = reduced_embeddings[labels == class_idx]
73
+ plt.scatter(
74
+ class_points[:, 0], class_points[:, 1],
75
+ label=f"Class {class_idx}",
76
+ alpha=0.6
77
+ )
78
+
79
+ # Customize the plot
80
+ plt.title(f"{label} ({method.upper()})")
81
+ plt.xlabel("Component 1")
82
+ plt.ylabel("Component 2")
83
+ plt.legend()
84
+ plt.show()
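A typical call, with cls_embeddings and labels as assumed variable names:

    visualize_embeddings(cls_embeddings, labels, method="tsne",
                         label="LWM CLS embeddings")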
85
+ #%%
86
+ def generate_gaussian_noise(data, snr_db):
87
+ """
88
+ Generate Gaussian noise given an SNR and apply it to the data.
89
+
90
+ Args:
91
+ data (torch.Tensor): Input data tensor of shape (n_samples, seq_len, feature_dim).
92
+ snr_db (float): Signal-to-Noise Ratio in decibels (dB).
93
+
94
+ Returns:
95
+ torch.Tensor: Data with Gaussian noise applied.
96
+ """
97
+ # Separate the input data to exclude the first channel
98
+ a = data[:, 1:, :] # Shape: (n_samples, seq_len-1, feature_dim)
99
+ flat_data = a.view(a.size(0), -1) # Flatten data to calculate power
100
+ signal_power = torch.mean(flat_data**2, dim=1, keepdim=True) # Shape: (n_samples, 1)
101
+ snr_linear = 10 ** (snr_db / 10)
102
+ noise_power = signal_power / snr_linear
103
+ noise = torch.randn_like(flat_data) * torch.sqrt(noise_power)
104
+ noise = noise.view_as(a)
105
+ noise = torch.cat((torch.zeros_like(data[:, :1, :]), noise), dim=1) # Add zero noise for the first channel
106
+
107
+ return noise
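To apply the noise, add the returned tensor to the input; for example, assuming preprocessed_data is a float tensor of shape (n_samples, seq_len, feature_dim):

    noise = generate_gaussian_noise(preprocessed_data, snr_db=10)   # 10 dB SNR
    noisy_data = preprocessed_data + noise                          # first (CLS) position stays clean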
108
+ #%%
109
+ def plot_coverage(rxs, cov_map, dpi=200, figsize=(6,4), cbar_title=None, title=False,
110
+ scat_sz=.5, tx_pos=None, tx_ori=None, legend=False, lims=None,
111
+ proj_3D=False, equal_aspect=False, tight=True, cmap='tab20'):
112
+
113
+ plt_params = {'cmap': cmap}
114
+ if lims:
115
+ plt_params['vmin'], plt_params['vmax'] = lims[0], lims[1]
116
+
117
+ n = 3 if proj_3D else 2 # n coordinates to consider 2 = xy | 3 = xyz
118
+
119
+ xyz = {'x': rxs[:,0], 'y': rxs[:,1]}
120
+ if proj_3D:
121
+ xyz['zs'] = rxs[:,2]
122
+
123
+ fig, ax = plt.subplots(dpi=dpi, figsize=figsize,
124
+ subplot_kw={'projection': '3d'} if proj_3D else {})
125
+
126
+ im = plt.scatter(**xyz, c=cov_map, s=scat_sz, marker='s', **plt_params)
127
+
128
+ cbar = plt.colorbar(im, label='' if not cbar_title else cbar_title)
129
+
130
+ plt.xlabel('x (m)')
131
+ plt.ylabel('y (m)')
132
+
133
+ # TX position
134
+ if tx_pos is not None:
135
+ ax.scatter(*tx_pos[:n], marker='P', c='r', label='TX')
136
+
137
+ # TX orientation
138
+ if tx_ori is not None and tx_pos is not None: # ori = [azi, el]
139
+ # positive azimuths point left (like positive angles in a unit circle)
140
+ # positive elevations point up
141
+ r = 30 # ref size of pointing direction
142
+ tx_lookat = np.copy(tx_pos)
143
+ tx_lookat[:2] += r * np.array([np.cos(tx_ori[2]), np.sin(tx_ori[2])]) # azimuth
144
+ tx_lookat[2] += r * np.sin(tx_ori[1]) # elevation
145
+
146
+ line_components = [[tx_pos[i], tx_lookat[i]] for i in range(n)]
147
+ line = {key:val for key,val in zip(['xs', 'ys', 'zs'], line_components)}
148
+ if n == 2:
149
+ ax.plot(line_components[0], line_components[1], c='k', alpha=.5, zorder=3)
150
+ else:
151
+ ax.plot(**line, c='k', alpha=.5, zorder=3)
152
+
153
+ if title:
154
+ ax.set_title(title)
155
+
156
+ if legend:
157
+ plt.legend(loc='upper center', ncols=10, framealpha=.5)
158
+
159
+ if tight:
160
+ s = 1
161
+ mins, maxs = np.min(rxs, axis=0)-s, np.max(rxs, axis=0)+s
162
+ if not proj_3D:
163
+ plt.xlim([mins[0], maxs[0]])
164
+ plt.ylim([mins[1], maxs[1]])
165
+ else:
166
+ ax.axes.set_xlim3d([mins[0], maxs[0]])
167
+ ax.axes.set_ylim3d([mins[1], maxs[1]])
168
+ if tx_pos is None:
169
+ ax.axes.set_zlim3d([mins[2], maxs[2]])
170
+ else:
171
+ ax.axes.set_zlim3d([np.min([mins[2], tx_pos[2]]),
172
+ np.max([mins[2], tx_pos[2]])])
173
+
174
+ if equal_aspect and not proj_3D:  # equal aspect disrupts 3D plots
175
+ plt.axis('scaled')
176
+
177
+ return fig, ax, cbar
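An illustrative call, with rx_positions (n_users x 3), beam_indices, and bs_position as assumed variables:

    fig, ax, cbar = plot_coverage(rx_positions, beam_indices,
                                  tx_pos=bs_position, tx_ori=None,
                                  cbar_title="Beam index", legend=True)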
178
+ #%%
179
+ def prepare_loaders(
180
+ preprocessed_data,
181
+ labels=None,
182
+ selected_patches_idxs=None,
183
+ input_type="raw",
184
+ task_type="classification",
185
+ feature_selection=False,
186
+ train_ratio=0.8,
187
+ batch_size=64,
188
+ seed=42 # Default seed for reproducibility
189
+ ):
190
+ """
191
+ Prepares datasets and data loaders for training and validation.
192
+
193
+ Args:
194
+ preprocessed_data (torch.Tensor): The input data, either raw or preprocessed.
195
+ labels (torch.Tensor, optional): The labels for classification tasks.
196
+ selected_patches_idxs (torch.Tensor, optional): Indices of selected patches for feature selection.
197
+ input_type (str): "raw" or "processed" to specify input data type.
198
+ task_type (str): "classification" or "regression".
199
+ feature_selection (bool): Whether to perform feature selection based on selected_patches_idxs.
200
+ train_ratio (float): Proportion of data to use for training (remaining for validation).
201
+ batch_size (int): Batch size for data loaders.
202
+ seed (int): Random seed for reproducibility.
203
+
204
+ Returns:
205
+ tuple: (train_loader, val_loader)
206
+ """
207
+ # Set random seed for reproducibility
208
+ torch.manual_seed(seed)
209
+
210
+ # Prepare samples
211
+ if input_type == "raw":
212
+ if feature_selection and selected_patches_idxs is not None:
213
+ batch_indices = torch.arange(preprocessed_data.size(0)).unsqueeze(1) # Shape: [batch_size, 1]
214
+ samples = torch.tensor(preprocessed_data[batch_indices, selected_patches_idxs], dtype=torch.float32)
215
+ else:
216
+ samples = torch.tensor(preprocessed_data[:, 1:], dtype=torch.float32) # raw_chs
217
+ else:
218
+ samples = torch.tensor(preprocessed_data, dtype=torch.float32)
219
+
220
+ # Prepare dataset
221
+ if task_type == "classification":
222
+ if labels is None:
223
+ raise ValueError("Labels are required for classification tasks.")
224
+ labels = torch.tensor(labels, dtype=torch.long)
225
+ dataset = TensorDataset(samples, labels)
226
+ target = 0 # REVISE if needed
227
+ elif task_type == "regression":
228
+ target = samples[:, 1:, :].view(samples.size(0), -1) # Reshape for regression targets
229
+ dataset = TensorDataset(samples, target)
230
+ else:
231
+ raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
232
+
233
+ # Set random seed for reproducibility
234
+ generator = torch.Generator().manual_seed(seed)
235
+
236
+ # Split dataset into training and validation
237
+ n_samples = len(dataset)
238
+ train_size = int(train_ratio * n_samples)
239
+ val_size = n_samples - train_size
240
+ train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
241
+
242
+ # Create DataLoaders
243
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
244
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
245
+
246
+ print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
247
+ return train_loader, val_loader, samples, target
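For instance, the downstream script could split CLS embeddings for a 16-beam classification task as follows (chs and labels as produced earlier in downstream.py; the ratio and batch size are placeholders):

    train_loader, val_loader, samples, target = prepare_loaders(
        preprocessed_data=chs,
        labels=labels,
        input_type="cls_emb",
        task_type="classification",
        train_ratio=0.1,
        batch_size=64)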