import numpy as np
import torch
import torch.nn as nn
from scripts.aggregator import Perceiver
from scripts.dataset import WSIPatchDataset
from scripts.feature_extractor import vit_base
from scripts.wsi_utils import extract_tissue_patch_coords
from torch.utils.data import DataLoader


class EXAONEPathV1Downstream(nn.Module):
    """Slide-level inference: extract tissue patch coordinates from a WSI,
    embed each patch with a ViT-Base feature extractor, and aggregate the
    patch features with a Perceiver into class probabilities."""

    def __init__(
        self, device: torch.device, step_size=256, patch_size=256, macenko=True
    ):
        super(EXAONEPathV1Downstream, self).__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.device = device

        self.feature_extractor = vit_base().to(self.device)
        self.feature_extractor.eval()

        self.agg_model = Perceiver(
            input_channels=768,
            input_axis=1,
            num_freq_bands=6,
            max_freq=10.0,
            depth=6,
            num_latents=256,
            latent_dim=512,
            cross_heads=1,
            latent_heads=8,
            cross_dim_head=64,
            latent_dim_head=64,
            num_classes=2,
            fourier_encode_data=False,
            self_per_cross_attn=2,
            pool="mean",
        )
        self.agg_model.to(self.device)
        self.agg_model.eval()

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        # Extract patches
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )
        features_list = []
        for count, patches in enumerate(patch_loader):
            processed = min((count + 1) * feature_extractor_batch_size, len(patch_dataset))
            print(
                f"batch {count + 1}/{len(patch_loader)}, {processed} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            feature = self.feature_extractor(patches)  # [B, 768] to match Perceiver input_channels
            feature /= feature.norm(dim=-1, keepdim=True)  # unit-normalize each patch feature
            feature = feature.cpu()  # synchronous copy; non_blocking into pageable CPU memory is unsafe
            features_list.append(feature)
        print("")
        print("Feature extraction finished")

        features = torch.cat(features_list)

        # Aggregate features
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(features[None].to(self.device))
        probs = Y_prob[0].cpu()

        return probs
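

# The normalization step in `forward` above rescales each patch feature to
# unit L2 norm before aggregation. A standalone sketch of that step
# (`l2_normalize` is an illustrative helper, not part of this repo; the
# 768-dim shape below mirrors the ViT-Base feature width used above):
def l2_normalize(features: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Divide each row (one patch feature) by its L2 norm, guarding
    against zero-length vectors with a small epsilon."""
    norms = np.linalg.norm(features, axis=-1, keepdims=True)
    return features / np.maximum(norms, eps)
# e.g. l2_normalize(np.ones((4, 768))) gives rows of norm 1.0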