wi-lab committed on
Commit
91e1a50
·
verified ·
1 Parent(s): 5b7442b

Upload downstream.py

Browse files
Files changed (1) hide show
  1. downstream.py +146 -145
downstream.py CHANGED
@@ -1,146 +1,147 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Fri Jan 10 11:11:58 2025
4
-
5
- This script evaluates downstream task performance by comparing models trained
6
- on raw channel representations versus those trained on LWM embeddings.
7
-
8
- @author: Sadjad Alikhani
9
- """
10
- #%% IMPORT PACKAGES & MODULES
11
- from input_preprocess import tokenizer, scenarios_list
12
- from inference import lwm_inference
13
- from utils import prepare_loaders
14
- from train import finetune
15
- import lwm_model
16
- import matplotlib.pyplot as plt
17
- import numpy as np
18
- import torch
19
- import torch.nn as nn
20
- import warnings
21
- warnings.filterwarnings("ignore", category=UserWarning)
22
- #%% DOWNSTERAM DATA GENERATION
23
- n_beams = 16
24
- task = ['Beam Prediction', 'LoS/NLoS Classification'][1]
25
- task_type = ["classification", "regression"][0]
26
- visualization_method = ["pca", "umap", "tsne"][2]
27
- input_types = ["cls_emb", "channel_emb", "raw"]
28
- train_ratios = [.001, .01, .05, .1, .25, .5, .8]
29
- fine_tuning_status = [None, ["layers.8", "layers.9", "layers.10", "layers.11"], "full"]
30
- selected_scenario_names = [scenarios_list()[6]]
31
- preprocessed_data, labels, raw_chs = tokenizer(
32
- selected_scenario_names,
33
- bs_idxs=[3],
34
- load_data=False,
35
- task=task,
36
- n_beams=n_beams)
37
- #%% LOAD THE MODEL
38
- gpu_ids = [0]
39
- device = torch.device("cuda:0")
40
- model = lwm_model.lwm().to(device)
41
-
42
- model_name = "model.pth"
43
- state_dict = torch.load(f"models/{model_name}", map_location=device)
44
- new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
45
- model.load_state_dict(new_state_dict)
46
-
47
- model = nn.DataParallel(model, gpu_ids)
48
- print(f"Model loaded successfully on GPU {device.index}")
49
- #%% 2D EMBEDDING SPACE VISUALIZATIONN BEFORE FINE-TUNING
50
- chs = lwm_inference(
51
- model,
52
- preprocessed_data,
53
- input_type="cls_emb",
54
- device=device,
55
- batch_size=64,
56
- visualization=False,
57
- labels=labels,
58
- visualization_method=visualization_method)
59
- #%% FINE-TUNE
60
- results = np.zeros((len(fine_tuning_status), len(input_types), len(train_ratios)))
61
- for fine_tuning_stat_idx, fine_tuning_stat in enumerate(fine_tuning_status):
62
- for input_type_idx, input_type in enumerate(input_types):
63
-
64
- if input_type == "raw" and fine_tuning_stat is not None:
65
- continue
66
-
67
- selected_patches_idxs = None
68
- for train_ratio_idx, train_ratio in enumerate(train_ratios):
69
-
70
- print(f"\nfine-tuning status: {fine_tuning_stat}")
71
- print(f"input type: {input_type}")
72
- print(f"train ratio: {train_ratio}\n")
73
-
74
- # PREPARE LOADERS
75
- train_loader, val_loader, samples, target = prepare_loaders(
76
- preprocessed_data=preprocessed_data,
77
- labels=labels,
78
- selected_patches_idxs=selected_patches_idxs,
79
- input_type=input_type,
80
- task_type=task_type,
81
- train_ratio=train_ratio,
82
- batch_size=128,
83
- seed=42
84
- )
85
-
86
- # FINE-TUNE LWM
87
- fine_tuned_model, best_model_path, train_losses, val_losses, f1_scores, attn_maps_ft = finetune(
88
- base_model=model,
89
- train_loader=train_loader,
90
- val_loader=val_loader,
91
- task_type=task_type,
92
- input_type=input_type,
93
- num_classes=n_beams if task=='Beam Prediction' else 2 if task=='LoS/NLoS Classification' else None,
94
- output_dim=target.shape[-1] if task_type =='regression' else None,
95
- use_custom_head=True,
96
- fine_tune_layers=fine_tuning_stat,
97
- optimizer_config={"lr": 1e-3},
98
- epochs=15,
99
- device=device,
100
- task=task
101
- )
102
-
103
- results[fine_tuning_stat_idx][input_type_idx][train_ratio_idx] = f1_scores[-1]
104
-
105
- markers = ['o', 's', 'D']
106
- labels = ['CLS Emb', 'CHS Emb', 'Raw']
107
- fine_tuning_status_labels = ['No FT', 'Partial FT', 'Full FT']
108
- line_styles = ['-', '--', ':']
109
- colors = plt.cm.viridis(np.linspace(0, 0.8, len(labels)))
110
- plt.figure(figsize=(12, 8), dpi=500)
111
- for ft_idx, (ft_status_label, line_style) in enumerate(zip(fine_tuning_status_labels, line_styles)):
112
- for idx, (marker, label, color) in enumerate(zip(markers, labels, colors)):
113
- # For "Raw Channels," only plot "No Fine-Tuning" case
114
- if label == "Raw" and ft_status_label != "No FT":
115
- continue
116
- # Simplify label for "Raw Channels" without fine-tuning
117
- plot_label = label if label != "Raw Channels" or ft_status_label != "No Fine-Tuning" else "Raw Channels"
118
- plt.plot(
119
- train_ratios,
120
- results[ft_idx, idx],
121
- marker=marker,
122
- linestyle=line_style,
123
- label=f"{plot_label} ({ft_status_label})" if label != "Raw Channels" else plot_label,
124
- color=color,
125
- linewidth=3,
126
- markersize=9
127
- )
128
- plt.xscale('log')
129
- plt.xlabel("Train Ratio", fontsize=20)
130
- plt.ylabel("F1-Score", fontsize=20)
131
- plt.legend(fontsize=17, loc="best")
132
- plt.grid(True, linestyle="--", alpha=0.7)
133
- plt.xticks(fontsize=17)
134
- plt.yticks(fontsize=17)
135
- plt.tight_layout()
136
- plt.show()
137
- #%% 2D EMBEDDING SPACE VISUALIZATIONN AFTER FINE-TUNING
138
- chs = lwm_inference(
139
- fine_tuned_model.model,
140
- preprocessed_data,
141
- input_type="cls_emb",
142
- device=device,
143
- batch_size=64,
144
- visualization=False,
145
- labels=labels,
 
146
  visualization_method=visualization_method)
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Jan 10 11:11:58 2025
4
+
5
+ This script evaluates downstream task performance by comparing models trained
6
+ on raw channel representations versus those trained on LWM embeddings.
7
+
8
+ @author: Sadjad Alikhani
9
+ """
10
+ #%% IMPORT PACKAGES & MODULES
11
+ from input_preprocess import tokenizer, scenarios_list
12
+ from inference import lwm_inference
13
+ from utils import prepare_loaders
14
+ from train import finetune
15
+ import lwm_model
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import warnings
21
+ warnings.filterwarnings("ignore", category=UserWarning)
#%% DOWNSTREAM DATA GENERATION
# Experiment configuration, followed by tokenization of one scenario.

# Beam count; forwarded to the tokenizer below.
n_beams = 16

# Downstream task name. Alternative: 'Beam Prediction'.
task = 'LoS/NLoS Classification'

# Problem type. Alternative: "regression".
task_type = "classification"

# Method name handed to the inference/visualization calls.
# Alternatives: "pca", "umap".
visualization_method = "tsne"

# Axes of the experiment sweep: input representation, fraction of data used
# for training, and fine-tuning mode (None, a list of layer names, or "full").
input_types = ["cls_emb", "channel_emb", "raw"]
train_ratios = [0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.8]
fine_tuning_status = [
    None,
    [f"layers.{layer_idx}" for layer_idx in range(8, 12)],
    "full",
]

# Only the scenario at index 6 of the project's scenario list is used.
selected_scenario_names = [scenarios_list()[6]]

preprocessed_data, labels, raw_chs = tokenizer(
    selected_scenario_names,
    bs_idxs=[3],
    load_data=False,
    task=task,
    n_beams=n_beams,
    manual_data=None,
)
#%% LOAD THE MODEL
# Build the LWM model on the first CUDA device, load the checkpoint weights
# into it, then wrap it in nn.DataParallel.
gpu_ids = [0]
device = torch.device("cuda:0")
model = lwm_model.lwm().to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
model_name = "model.pth"
state_dict = torch.load(f"models/{model_name}", map_location=device)

# A checkpoint saved from an nn.DataParallel-wrapped model has "module."
# in front of every key. Strip it only where it is a leading prefix: a
# blanket str.replace would also mangle keys that merely contain "module."
# further in (e.g. "submodule.weight" would become "subweight").
wrapper_prefix = "module."
new_state_dict = {
    (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(new_state_dict)

# Wrap only after the weights are loaded, so the un-prefixed keys match the
# bare model.
model = nn.DataParallel(model, gpu_ids)
print(f"Model loaded successfully on GPU {device.index}")
#%% 2D EMBEDDING SPACE VISUALIZATION BEFORE FINE-TUNING
# Run the loaded (not yet fine-tuned) model over the preprocessed data,
# requesting the "cls_emb" input type, and keep the result in `chs`.
# Plotting is switched off in this call (visualization=False); the labels and
# the visualization method are still passed through.
pretrained_inference_options = {
    "input_type": "cls_emb",
    "device": device,
    "batch_size": 64,
    "visualization": False,
    "labels": labels,
    "visualization_method": visualization_method,
}
chs = lwm_inference(model, preprocessed_data, **pretrained_inference_options)
#%% FINE-TUNE
# Sweep over every (fine-tuning mode, input type, train ratio) combination and
# store the last entry of the F1-score list returned by each run.
# results[i, j, k] follows the order of fine_tuning_status, input_types and
# train_ratios; combinations that are skipped keep their initial value of 0.
results = np.zeros((len(fine_tuning_status), len(input_types), len(train_ratios)))

# The class count depends only on the task, so it is resolved once up front.
# Any task name not listed here yields None.
num_classes_by_task = {'Beam Prediction': n_beams, 'LoS/NLoS Classification': 2}
num_classes = num_classes_by_task.get(task)

# No patch selection is requested from prepare_loaders in this sweep.
selected_patches_idxs = None

for ft_row, fine_tuning_stat in enumerate(fine_tuning_status):
    for input_col, input_type in enumerate(input_types):

        # Raw channels are only evaluated for the no-fine-tuning mode.
        if fine_tuning_stat is not None and input_type == "raw":
            continue

        for ratio_col, train_ratio in enumerate(train_ratios):

            print(f"\nfine-tuning status: {fine_tuning_stat}")
            print(f"input type: {input_type}")
            print(f"train ratio: {train_ratio}\n")

            # PREPARE LOADERS
            train_loader, val_loader, samples, target = prepare_loaders(
                preprocessed_data=preprocessed_data,
                labels=labels,
                selected_patches_idxs=selected_patches_idxs,
                input_type=input_type,
                task_type=task_type,
                train_ratio=train_ratio,
                batch_size=128,
                seed=42,
            )

            # An output dimension is only supplied for regression; it is taken
            # from the last axis of the target returned by prepare_loaders.
            if task_type == 'regression':
                output_dim = target.shape[-1]
            else:
                output_dim = None

            # FINE-TUNE LWM
            # NOTE(review): the same `model` object is handed to every run;
            # whether finetune copies it or trains it in place is not visible
            # here -- confirm that later runs do not inherit earlier updates.
            (
                fine_tuned_model,
                best_model_path,
                train_losses,
                val_losses,
                f1_scores,
                attn_maps_ft,
            ) = finetune(
                base_model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                task_type=task_type,
                input_type=input_type,
                num_classes=num_classes,
                output_dim=output_dim,
                use_custom_head=True,
                fine_tune_layers=fine_tuning_stat,
                optimizer_config={"lr": 1e-3},
                epochs=15,
                device=device,
                task=task,
            )

            results[ft_row, input_col, ratio_col] = f1_scores[-1]

# Plot the F1-scores stored in `results` against the train ratio, one curve
# per (fine-tuning mode, input type) pair: colour and marker encode the input
# type, line style encodes the fine-tuning mode.
markers = ['o', 's', 'D']
# Deliberately not called `labels`: that name holds the dataset labels
# returned by the tokenizer, which are passed to lwm_inference again after
# this section and must not be overwritten by legend text.
input_type_labels = ['CLS Emb', 'CHS Emb', 'Raw']
fine_tuning_status_labels = ['No FT', 'Partial FT', 'Full FT']
line_styles = ['-', '--', ':']
colors = plt.cm.viridis(np.linspace(0, 0.8, len(input_type_labels)))
plt.figure(figsize=(12, 8), dpi=500)
for ft_idx, (ft_status_label, line_style) in enumerate(zip(fine_tuning_status_labels, line_styles)):
    for idx, (marker, label, color) in enumerate(zip(markers, input_type_labels, colors)):
        is_raw = label == "Raw"
        # Raw channels are only evaluated without fine-tuning; their other
        # rows in `results` were never filled in, so skip them.
        if is_raw and ft_status_label != "No FT":
            continue
        # The raw baseline has a single curve, so its legend entry carries no
        # fine-tuning suffix. (The previous comparison against "Raw Channels"
        # could never match the short label "Raw".)
        plot_label = label if is_raw else f"{label} ({ft_status_label})"
        plt.plot(
            train_ratios,
            results[ft_idx, idx],
            marker=marker,
            linestyle=line_style,
            label=plot_label,
            color=color,
            linewidth=3,
            markersize=9,
        )
# Train ratios span several orders of magnitude, hence the log x-axis.
plt.xscale('log')
plt.xlabel("Train Ratio", fontsize=20)
plt.ylabel("F1-Score", fontsize=20)
plt.legend(fontsize=17, loc="best")
plt.grid(True, linestyle="--", alpha=0.7)
plt.xticks(fontsize=17)
plt.yticks(fontsize=17)
plt.tight_layout()
plt.show()
#%% 2D EMBEDDING SPACE VISUALIZATION AFTER FINE-TUNING
# Repeat the "cls_emb" inference call, this time on the `.model` attribute of
# `fine_tuned_model`, i.e. of the object returned by the most recent finetune
# call in the sweep. The result replaces the earlier contents of `chs`.
# `labels` is expected to be the labels that came out of the tokenizer step --
# make sure nothing in between has rebound that name.
finetuned_inference_options = {
    "input_type": "cls_emb",
    "device": device,
    "batch_size": 64,
    "visualization": False,
    "labels": labels,
    "visualization_method": visualization_method,
}
chs = lwm_inference(
    fine_tuned_model.model,
    preprocessed_data,
    **finetuned_inference_options,
)