import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from scripts.aggregator import Perceiver
from scripts.dataset import WSIPatchDataset
from scripts.feature_extractor import vit_base
from scripts.wsi_utils import extract_tissue_patch_coords


class EXAONEPathV1Downstream(nn.Module):
    """Slide-level (WSI) downstream classifier.

    Pipeline: tissue patch extraction -> ViT-Base patch feature
    extraction (768-dim, L2-normalized) -> Perceiver aggregation ->
    2-class probabilities.
    """

    def __init__(
        self, device: torch.device, step_size=256, patch_size=256, macenko=True
    ):
        """
        Args:
            device: device on which both sub-models run.
            step_size: stride in pixels between consecutive patches.
            patch_size: side length in pixels of each extracted patch.
            macenko: whether to apply Macenko stain normalization to patches.
        """
        super().__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.device = device

        # Patch-level feature extractor (ViT-Base), frozen in eval mode.
        self.feature_extractor = vit_base().to(self.device)
        self.feature_extractor.eval()

        # Slide-level aggregator; input_channels matches the ViT-Base
        # embedding size (768).
        self.agg_model = Perceiver(
            input_channels=768,
            input_axis=1,
            num_freq_bands=6,
            max_freq=10.0,
            depth=6,
            num_latents=256,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=2,
            fourier_encode_data=False,
            self_per_cross_attn=2,
            pool="mean",
        )
        self.agg_model.to(self.device)
        self.agg_model.eval()

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Classify one whole-slide image.

        Args:
            svs_path: path to the whole-slide image (.svs) file.
            feature_extractor_batch_size: batch size used during patch
                feature extraction.

        Returns:
            1-D CPU tensor of class probabilities (2 classes).
        """
        # Locate tissue-bearing patch coordinates on the slide.
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features.
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )

        features_list = []
        num_processed = 0  # exact count, robust to a partial final batch
        for count, patches in enumerate(patch_loader):
            print(
                f"batch {count+1}/{len(patch_loader)}, {num_processed} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)
            feature = self.feature_extractor(patches)  # [B, 768]
            feature /= feature.norm(dim=-1, keepdim=True)  # L2-normalize per patch
            # Blocking copy on purpose: non_blocking=True on a GPU->CPU
            # transfer into unpinned memory may let the tensor be read
            # before the async copy completes.
            features_list.append(feature.cpu())
            num_processed += patches.shape[0]
        print("")
        print("Feature extraction finished")
        features = torch.cat(features_list)

        # Aggregate patch features into slide-level class probabilities.
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(features[None].to(self.device))
        probs = Y_prob[0].cpu()
        return probs