import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from scripts.aggregator import Perceiver
from scripts.dataset import WSIPatchDataset
from scripts.feature_extractor import vit_base
from scripts.wsi_utils import extract_tissue_patch_coords


class EXAONEPathV1Downstream(nn.Module):
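    """Slide-level inference pipeline: extracts tissue patches from a WSI,
    embeds each patch with a ViT-Base feature extractor, and aggregates the
    embeddings with a Perceiver into class probabilities.
    """
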
    def __init__(
        self, device: torch.device, step_size=256, patch_size=256, macenko=True
    ):
        super(EXAONEPathV1Downstream, self).__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.device = device
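
        # Patch-level feature extractor (ViT-Base), used in eval mode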
        self.feature_extractor = vit_base()
        self.feature_extractor = self.feature_extractor.to(self.device)
        self.feature_extractor.eval()
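
        # Perceiver aggregator: attends over the bag of 768-d patch features
        # and produces slide-level predictions (num_classes=2)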
        self.agg_model = Perceiver(
            input_channels=768,
            input_axis=1,
            num_freq_bands=6,
            max_freq=10.0,
            depth=6,
            num_latents=256,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=2,
            fourier_encode_data=False,
            self_per_cross_attn=2,
            pool="mean",
        )
        self.agg_model.to(self.device)
        self.agg_model.eval()

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
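        """Run slide-level inference on one WSI and return class probabilities."""
        # 1. Locate tissue regions and compute patch coordinates over the slide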
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )
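
        # 2. Embed every tissue patch with the ViT feature extractor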
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )
        features_list = []
        for count, patches in enumerate(patch_loader):
            print(
                f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            feature = self.feature_extractor(patches)
            feature /= feature.norm(dim=-1, keepdim=True)  # L2-normalize each embedding
            feature = feature.cpu()  # synchronous copy; a non-blocking D2H copy could race the later torch.cat
            features_list.append(feature)
        print("")
        print("Feature extraction finished")

        features = torch.cat(features_list)
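
        # 3. Aggregate patch features into slide-level class probabilities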
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(features[None].to(self.device))
        probs = Y_prob[0].cpu()

        return probs
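

# Minimal usage sketch (illustrative only): the slide path below is a placeholder,
# and pretrained weights for the ViT backbone and the Perceiver aggregator are
# assumed to be loaded elsewhere; as written here, the modules are randomly initialized.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EXAONEPathV1Downstream(device=device)
    probs = model("path/to/slide.svs", feature_extractor_batch_size=8)
    print(f"class probabilities: {probs.tolist()}")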