File size: 3,114 Bytes
287a683 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import numpy as np
import torch
import torch.nn as nn
from scripts.aggregator import Perceiver
from scripts.dataset import WSIPatchDataset
from scripts.feature_extractor import vit_base
from scripts.wsi_utils import extract_tissue_patch_coords
from torch.utils.data import DataLoader
class EXAONEPathV1Downstream(nn.Module):
    """End-to-end WSI (whole-slide image) inference pipeline.

    Given a path to an .svs slide, extracts tissue patch coordinates,
    computes per-patch ViT-Base features, and aggregates them with a
    Perceiver into a 2-class slide-level probability vector.
    """

    def __init__(
        self, device: torch.device, step_size=256, patch_size=256, macenko=True
    ):
        """Build the frozen feature extractor and Perceiver aggregator.

        Args:
            device: torch device both sub-models are moved to.
            step_size: stride (in pixels) between patch coordinates.
            patch_size: side length (in pixels) of each square patch.
            macenko: whether to apply Macenko stain normalization to patches.
        """
        super().__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.device = device
        # ViT-Base patch-level feature extractor, frozen in eval mode.
        # (Removed a redundant self-assignment from the original code.)
        self.feature_extractor = vit_base().to(self.device)
        self.feature_extractor.eval()
        # Perceiver aggregator: consumes 768-dim patch features (ViT-Base
        # embedding width) and outputs 2-class slide-level logits.
        self.agg_model = Perceiver(
            input_channels=768,
            input_axis=1,
            num_freq_bands=6,
            max_freq=10.0,
            depth=6,
            num_latents=256,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=2,
            fourier_encode_data=False,
            self_per_cross_attn=2,
            pool="mean",
        )
        self.agg_model.to(self.device)
        self.agg_model.eval()

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Run slide-level inference on one WSI.

        Args:
            svs_path: path to the .svs slide file.
            feature_extractor_batch_size: patch batch size for the ViT.

        Returns:
            1-D CPU tensor of class probabilities for the slide.
        """
        # Locate tissue-containing patches on the slide.
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features with the frozen ViT.
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )

        features_list = []
        num_processed = 0  # exact running patch count (last batch may be partial)
        for batch_idx, patches in enumerate(patch_loader):
            print(
                f"batch {batch_idx + 1}/{len(patch_loader)}, {num_processed} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)
            feature = self.feature_extractor(patches)  # [B, 768] (ViT-Base dim)
            feature /= feature.norm(dim=-1, keepdim=True)  # L2-normalize features
            # Synchronous device->CPU copy: a non_blocking copy to pageable CPU
            # memory can be read (by torch.cat below) before the transfer
            # completes, yielding garbage values.
            features_list.append(feature.cpu())
            num_processed += patches.size(0)
        print("")
        print("Feature extraction finished")
        features = torch.cat(features_list)

        # Aggregate patch features into a slide-level prediction.
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(features[None].to(self.device))
        probs = Y_prob[0].cpu()
        return probs
|