편명장/님/(myeongjang.pyeon)
initial commit
287a683
import numpy as np
import torch
import torch.nn as nn
from scripts.aggregator import Perceiver
from scripts.dataset import WSIPatchDataset
from scripts.feature_extractor import vit_base
from scripts.wsi_utils import extract_tissue_patch_coords
from torch.utils.data import DataLoader
class EXAONEPathV1Downstream(nn.Module):
    """Slide-level inference: tissue patching -> ViT patch features -> Perceiver.

    ``forward`` takes the path of a whole-slide image, extracts tissue patch
    coordinates, embeds every patch with a ViT-base feature extractor,
    L2-normalizes the embeddings and aggregates them with a Perceiver into a
    single slide-level prediction.

    NOTE(review): this constructor builds both networks but loads no weights;
    presumably pretrained weights are loaded by the caller — confirm.
    """

    def __init__(
        self, device: torch.device, step_size=256, patch_size=256, macenko=True
    ):
        """Build the feature extractor and aggregator on ``device`` in eval mode.

        Args:
            device: Device both models run on. A ``torch.device`` or anything
                ``torch.device()`` accepts (e.g. ``"cuda:0"``, ``"cpu"``).
            step_size: Stride passed to ``extract_tissue_patch_coords``.
            patch_size: Patch side length passed to the coordinate extractor
                and to ``WSIPatchDataset``.
            macenko: Forwarded to ``WSIPatchDataset``; presumably toggles
                Macenko stain normalization — confirm in ``scripts.dataset``.
        """
        super().__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        # Normalize so plain strings work too; forward() relies on `.type`.
        self.device = torch.device(device)

        self.feature_extractor = vit_base().to(self.device)
        self.feature_extractor.eval()

        # input_channels must equal the width of the patch features produced
        # by the feature extractor (768 here).
        self.agg_model = Perceiver(
            input_channels=768,
            input_axis=1,
            num_freq_bands=6,
            max_freq=10.0,
            depth=6,
            num_latents=256,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=2,
            fourier_encode_data=False,
            self_per_cross_attn=2,
            pool="mean",
        )
        self.agg_model.to(self.device)
        self.agg_model.eval()

    @torch.no_grad()
    def _extract_features(self, svs_path: str, batch_size: int) -> torch.Tensor:
        """Return L2-normalized per-patch features of the slide, stacked on CPU.

        Raises:
            RuntimeError: if no tissue patches were extracted from the slide.
        """
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
        )
        on_cuda = self.device.type == "cuda"
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=batch_size,
            # Worker processes / pinned memory only pay off when feeding a GPU.
            num_workers=batch_size * 2 if on_cuda else 0,
            pin_memory=on_cuda,
        )

        features_list = []
        for count, patches in enumerate(patch_loader):
            print(
                f"batch {count+1}/{len(patch_loader)}, {count * batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)
            # [B, D]; D must equal the aggregator's input_channels (768).
            feature = self.feature_extractor(patches)
            feature /= feature.norm(dim=-1, keepdim=True)  # use normalized feature
            # Offload to CPU so GPU memory does not grow with the slide size.
            features_list.append(feature.to("cpu", non_blocking=True))
        print("")
        print("Feature extraction finished")

        if not features_list:
            raise RuntimeError(
                f"No tissue patches were extracted from {svs_path!r}; "
                "nothing to aggregate."
            )
        if on_cuda:
            # Non-blocking device-to-host copies may still be in flight; the
            # CPU tensors must not be read before they complete.
            torch.cuda.synchronize(self.device)
        return torch.cat(features_list)

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Run slide-level inference on one whole-slide image.

        Args:
            svs_path: Path to the slide file.
            feature_extractor_batch_size: Patches per feature-extractor batch.
                On CUDA it also sets the DataLoader worker count (2x this).

        Returns:
            ``Y_prob[0]`` from the aggregator, moved to CPU — presumably the
            class-probability vector for this slide (``num_classes=2``).

        Raises:
            RuntimeError: if no tissue patches were extracted from the slide.
        """
        features = self._extract_features(svs_path, feature_extractor_batch_size)

        self.agg_model.eval()
        # Leading axis of size 1: the whole slide is one bag of patch features.
        _logits, y_prob, _y_hat = self.agg_model(features[None].to(self.device))
        return y_prob[0].cpu()