File size: 1,980 Bytes
8920c6e
 
 
 
 
 
 
 
 
 
 
 
 
f73f011
8920c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024

This script performs LWM inference on raw channel representations.

@author: Sadjad Alikhani
"""
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import visualize_embeddings
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')
#%%
def lwm_inference(model, data, input_type="cls_emb", device="cpu", batch_size=64, visualization=False, labels=None, visualization_method="t-sne"):
    """Run batched inference with ``model`` and return the selected output.

    Args:
        model: Module that is switched to eval mode and called as
            ``model(batch)``; element ``[0]`` of its result is indexed as a
            3-D tensor. The model is NOT moved to ``device`` here
            (NOTE(review): callers presumably place it there beforehand).
        data: Input tensor, batched along dim 0.
        input_type: Which representation to return:
            ``"raw"`` returns ``data`` untouched (no model call and no
            visualization); ``"cls_emb"`` keeps index 0 along dim 1 of the
            model output; ``"channel_emb"`` keeps indices 1 onward along
            dim 1.
        device: Device each input batch is moved to before the forward pass.
        batch_size: Number of samples per forward pass.
        visualization: When true, plot both the embeddings and the original
            ``data`` (each flattened to 2-D) via ``visualize_embeddings``.
        labels: Passed through to ``visualize_embeddings``.
        visualization_method: Passed as ``method`` to
            ``visualize_embeddings``.

    Returns:
        ``data`` itself for ``"raw"``; otherwise the per-batch outputs
        concatenated along dim 0 and cast to float32.

    Raises:
        ValueError: If ``input_type`` is not one of the supported values.
        RuntimeError: From ``torch.cat`` if ``data`` yields no batches.
    """
    valid_input_types = ("raw", "cls_emb", "channel_emb")
    # Fail fast: an unknown type used to run the whole inference loop and
    # then die inside torch.cat on an empty list with an opaque message.
    if input_type not in valid_input_types:
        raise ValueError(
            f"Unknown input_type {input_type!r}; "
            f"expected one of {valid_input_types}."
        )

    if input_type == "raw":
        return data

    dataloader = DataLoader(TensorDataset(data), batch_size=batch_size, shuffle=False)

    embeddings = []
    model.eval()
    with torch.no_grad():
        with tqdm(dataloader, desc="Inference", unit="batch") as progress:
            for batch in progress:
                # TensorDataset wraps each sample in a 1-tuple.
                input_ids = batch[0].to(device)
                output = model(input_ids)[0]

                if input_type == "cls_emb":
                    embeddings.append(output[:, 0, :])
                else:  # "channel_emb"
                    embeddings.append(output[:, 1:, :])

    output_total = torch.cat(embeddings, dim=0).float()

    if visualization:
        # reshape (not view) so non-contiguous tensors can be flattened too.
        visualize_embeddings(output_total.reshape(output_total.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Embedding Space")
        visualize_embeddings(data.reshape(data.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Original Space")

    return output_total