# -*- coding: utf-8 -*-
"""

Created on Fri Jan 10 11:11:58 2025



This script evaluates downstream task performance by comparing models trained 

on raw channel representations versus those trained on LWM embeddings.



@author: Sadjad Alikhani

"""
#%% IMPORT PACKAGES & MODULES
from input_preprocess import tokenizer, scenarios_list
from inference import lwm_inference
from utils import prepare_loaders
from train import finetune
import lwm_model
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
#%% DOWNSTREAM DATA GENERATION
n_beams = 16
task = ['Beam Prediction', 'LoS/NLoS Classification'][1]
task_type = ["classification", "regression"][0]
visualization_method = ["pca", "umap", "tsne"][2]
input_types = ["cls_emb", "channel_emb", "raw"]
train_ratios = [.001, .01, .05, .1, .25, .5, .8]
fine_tuning_status = [None, ["layers.8", "layers.9", "layers.10", "layers.11"], "full"]
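# Each bracketed index above picks one option from its list. fine_tuning_status
# spans three regimes: None freezes the LWM backbone (probing only), the
# layer-name list unfreezes just the last four encoder blocks (partial
# fine-tuning), and "full" trains every layer.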
selected_scenario_names = [scenarios_list()[6]] 
preprocessed_data, labels, raw_chs = tokenizer(
    selected_scenario_names, 
    bs_idxs=[3], 
    load_data=False, 
    task=task, 
    n_beams=n_beams,
    manual_data=None)
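# tokenizer() yields the tokenized channel patches consumed by LWM, the
# downstream-task labels, and the raw channels used for the "raw" baseline.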
#%% LOAD THE MODEL
gpu_ids = [0]
device = torch.device(f"cuda:{gpu_ids[0]}")  # the script assumes a CUDA-capable GPU
model = lwm_model.lwm().to(device)

model_name = "model.pth"
state_dict = torch.load(f"models/{model_name}", map_location=device)
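# Checkpoints saved from a DataParallel-wrapped model prefix every parameter
# key with "module."; strip it so the weights load into the bare model.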
new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)

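# Re-wrap in DataParallel so inference and fine-tuning can fan out across
# gpu_ids (a transparent wrapper when only a single GPU is listed).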
model = nn.DataParallel(model, gpu_ids)
print(f"Model loaded successfully on GPU {device.index}")
#%% 2D EMBEDDING SPACE VISUALIZATION BEFORE FINE-TUNING
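# Project the CLS embeddings of the pretrained (not yet fine-tuned) model;
# set visualization=True to render the 2D plot with the method chosen above.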
chs = lwm_inference(
    model, 
    preprocessed_data, 
    input_type="cls_emb", 
    device=device, 
    batch_size=64, 
    visualization=False, 
    labels=labels, 
    visualization_method=visualization_method)
#%% FINE-TUNE
results = np.zeros((len(fine_tuning_status), len(input_types), len(train_ratios)))
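# results[ft_status, input_type, train_ratio] holds the final-epoch F1-score
# of each configuration; raw-channel rows are filled only for the no-FT case.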
for fine_tuning_stat_idx, fine_tuning_stat in enumerate(fine_tuning_status):
    for input_type_idx, input_type in enumerate(input_types):
        
        # Raw channels bypass the LWM backbone, so fine-tuning does not apply;
        # evaluate the raw baseline only in the frozen (no-FT) pass.
        if input_type == "raw" and fine_tuning_stat is not None:
            continue
        
        selected_patches_idxs = None
        for train_ratio_idx, train_ratio in enumerate(train_ratios):
            
            print(f"\nfine-tuning status: {fine_tuning_stat}")
            print(f"input type: {input_type}")
            print(f"train ratio: {train_ratio}\n")
            
            # PREPARE LOADERS
            train_loader, val_loader, samples, target = prepare_loaders(
                preprocessed_data=preprocessed_data,
                labels=labels,
                selected_patches_idxs=selected_patches_idxs,
                input_type=input_type,
                task_type=task_type,
                train_ratio=train_ratio,
                batch_size=128,
                seed=42
            )
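            # prepare_loaders() splits the chosen input representation by
            # train_ratio; the returned target tensor also sizes the output
            # head for regression tasks (target.shape[-1] below).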
            
            # FINE-TUNE LWM 
            fine_tuned_model, best_model_path, train_losses, val_losses, f1_scores, attn_maps_ft = finetune(
                base_model=model, 
                train_loader=train_loader,
                val_loader=val_loader,
                task_type=task_type,
                input_type=input_type,
                num_classes=n_beams if task=='Beam Prediction' else 2 if task=='LoS/NLoS Classification' else None,
                output_dim=target.shape[-1] if task_type =='regression' else None,
                use_custom_head=True,
                fine_tune_layers=fine_tuning_stat,
                optimizer_config={"lr": 1e-3},
                epochs=15,
                device=device,
                task=task
            )
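            # finetune() returns the tuned model plus per-epoch training
            # curves; the last F1-score summarizes this configuration.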
            
            results[fine_tuning_stat_idx, input_type_idx, train_ratio_idx] = f1_scores[-1]

markers = ['o', 's', 'D']
curve_labels = ['CLS Emb', 'CHS Emb', 'Raw']  # renamed from `labels` to avoid clobbering the task labels returned by tokenizer()
fine_tuning_status_labels = ['No FT', 'Partial FT', 'Full FT']
line_styles = ['-', '--', ':']
colors = plt.cm.viridis(np.linspace(0, 0.8, len(curve_labels)))
plt.figure(figsize=(12, 8), dpi=500)
for ft_idx, (ft_status_label, line_style) in enumerate(zip(fine_tuning_status_labels, line_styles)):
    for idx, (marker, label, color) in enumerate(zip(markers, curve_labels, colors)):
        # Raw channels are only evaluated without fine-tuning
        if label == "Raw" and ft_status_label != "No FT":
            continue
        # The raw-channel curve carries no fine-tuning suffix in the legend
        plot_label = label if label == "Raw" else f"{label} ({ft_status_label})"
        plt.plot(
            train_ratios,
            results[ft_idx, idx],
            marker=marker,
            linestyle=line_style,
            label=plot_label,
            color=color,
            linewidth=3,
            markersize=9
        )
plt.xscale('log')
plt.xlabel("Train Ratio", fontsize=20)
plt.ylabel("F1-Score", fontsize=20)
plt.legend(fontsize=17, loc="best")
plt.grid(True, linestyle="--", alpha=0.7)
plt.xticks(fontsize=17)
plt.yticks(fontsize=17)
plt.tight_layout()
plt.show()
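# Optionally persist the sweep for later analysis (filename is just a
# suggestion, not part of the original pipeline):
# np.save("downstream_f1_results.npy", results)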
#%% 2D EMBEDDING SPACE VISUALIZATION AFTER FINE-TUNING
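# Note: fine_tuned_model here is whichever configuration the sweep finished
# on; its .model attribute exposes the underlying LWM backbone for inference.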
chs = lwm_inference(
    fine_tuned_model.model, 
    preprocessed_data, 
    input_type="cls_emb", 
    device=device, 
    batch_size=64, 
    visualization=False, 
    labels=labels, 
    visualization_method=visualization_method)