File size: 1,980 Bytes
8920c6e f73f011 8920c6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024
This script performs LWM inference on raw channel representations.
@author: Sadjad Alikhani
"""
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import visualize_embeddings
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
#%%
def lwm_inference(model, data: torch.Tensor, input_type: str = "cls_emb", device="cpu",
                  batch_size: int = 64, visualization: bool = False, labels=None,
                  visualization_method: str = "t-sne") -> torch.Tensor:
    """Run ``model`` over ``data`` in batches and return the selected embeddings.

    Args:
        model: Callable module; ``model.eval()`` is invoked and the first
            element of ``model(batch)`` is used as the output tensor.
        data: Input tensor; it is batched along dimension 0.
        input_type: One of
            ``"raw"`` - return ``data`` unchanged, the model is not called;
            ``"cls_emb"`` - keep index 0 along dimension 1 of the output;
            ``"channel_emb"`` - keep indices 1 and onward along dimension 1.
        device: Device each batch is moved to before the forward pass. The
            model itself is not moved here, so the caller must place it.
        batch_size: Number of samples per forward pass.
        visualization: When true, plot the embeddings and the original data
            through ``visualize_embeddings`` (embedding modes only).
        labels: Forwarded unchanged to ``visualize_embeddings``.
        visualization_method: Forwarded as ``method`` to
            ``visualize_embeddings``.

    Returns:
        ``data`` itself for ``"raw"``; otherwise the per-batch selections
        concatenated along dimension 0 and cast to float32.

    Raises:
        ValueError: If ``input_type`` is not one of the supported values.
    """
    if input_type == "raw":
        return data

    # Fail fast: previously an unknown mode ran the full inference loop,
    # collected nothing, and crashed in torch.cat with an unrelated message.
    if input_type not in ("cls_emb", "channel_emb"):
        raise ValueError(
            f"Unsupported input_type {input_type!r}; "
            "expected 'raw', 'cls_emb' or 'channel_emb'."
        )

    # The mode is loop-invariant, so decide once which slice to keep.
    # Indexing below assumes a 3-D model output - presumably
    # (batch, sequence, features); TODO confirm against the model.
    keep_first_only = input_type == "cls_emb"

    loader = DataLoader(TensorDataset(data), batch_size=batch_size, shuffle=False)

    collected = []
    model.eval()
    with torch.no_grad():
        with tqdm(loader, desc="Inference", unit="batch") as progress:
            for batch in progress:
                # TensorDataset wraps a single tensor, so each batch is a
                # one-element sequence.
                input_ids = batch[0].to(device)
                output = model(input_ids)[0]
                if keep_first_only:
                    collected.append(output[:, 0, :])
                else:
                    collected.append(output[:, 1:, :])

    output_total = torch.cat(collected, dim=0).float()

    # NOTE(review): the source was flattened, so the nesting of this step is
    # reconstructed; it is treated as part of the embedding path only.
    if visualization:
        # reshape (not view) so non-contiguous inputs are flattened as well.
        visualize_embeddings(output_total.reshape(output_total.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Embedding Space")
        visualize_embeddings(data.reshape(data.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Original Space")

    return output_total
|