---
library_name: project-lighter
tags:
- lighter
- model_hub_mixin
- pytorch_model_hub_mixin
language: en
license: apache-2.0
arxiv: 2501.09001
---

# CT-FM Feature Extractor

This model is a feature extractor for CT-FM, a model pre-trained with contrastive self-supervised learning on a large dataset of 148,000 CT scans from the Imaging Data Commons.

The backbone is based on SegResNet, a 3D U-Net variant. If you only want to load the model and fine-tune it, you can skip the feature extraction workflow.

## Running instructions



This notebook demonstrates how to:
1. Load an SSL pre-trained model
2. Set up preprocessing and postprocessing pipelines
3. Perform inference on CT volumes
4. Plot the distribution of the extracted features

## Setup
Install requirements and import necessary packages


```python
# Install lighter_zoo package
%pip install lighter_zoo -U -qq
```

    Note: you may need to restart the kernel to use updated packages.



```python
# Imports
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)
from monai.inferers import SlidingWindowInferer
```

## Load Model
Download and initialize the pre-trained model from the Hugging Face Hub


```python
# Load pre-trained model
model = SegResEncoder.from_pretrained(
    "project-lighter/ct_fm_feature_extractor"
)
```

## Setup Processing Pipelines
Define preprocessing transforms


```python
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                         # Ensure correct data type
    Orientation(axcodes="SPL"),           # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True       # Clip values outside range
    ),
    CropForeground()    # Remove background to reduce computation
])
```

    monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
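
For reference, `ScaleIntensityRange` with `clip=True` maps HU values linearly from `[a_min, a_max]` to `[b_min, b_max]` and then clips the result to `[b_min, b_max]`. The sketch below reproduces that arithmetic in plain Python for a few HU values; `scale_hu` is a hypothetical helper written for illustration, not part of MONAI.

```python
def scale_hu(hu, a_min=-1024.0, a_max=2048.0, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map a Hounsfield-unit value from [a_min, a_max] to [b_min, b_max].

    Hypothetical helper mirroring the ScaleIntensityRange settings used above.
    """
    scaled = (hu - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        # Clamp to the target range, as clip=True does
        scaled = min(max(scaled, b_min), b_max)
    return scaled


# -1024 HU (air / padding floor), 0 HU (water), 1000 HU (bone), 3000 HU (beyond a_max)
for hu in (-1024, 0, 1000, 3000):
    print(hu, round(scale_hu(hu), 4))
```

With these settings, water (0 HU) lands at 1/3, and anything denser than 2048 HU (for example, metal) saturates at 1.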

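With the defaults used here (a `select_fn` that keeps values greater than 0, and no margin), `CropForeground` crops the volume to the bounding box of positive voxels. After intensity scaling, only voxels at or below -1024 HU are exactly 0, so the crop trims border regions made entirely of such values (for example, padding outside the scanner's field of view); air measured at around -1000 HU scales to a small positive value and is kept. The sketch below illustrates the bounding-box idea on a toy 2D slice; `crop_foreground_2d` is a hypothetical, simplified stand-in, not MONAI's implementation.

```python
def crop_foreground_2d(image):
    """Crop a 2D list of lists to the bounding box of values > 0 (simplified sketch)."""
    rows = [r for r, row in enumerate(image) if any(v > 0 for v in row)]
    cols = [c for c in range(len(image[0])) if any(row[c] > 0 for row in image)]
    if not rows:
        # No foreground: return the input unchanged (a simplifying choice for this sketch)
        return image
    r0, r1 = rows[0], rows[-1] + 1
    c0, c1 = cols[0], cols[-1] + 1
    return [row[c0:c1] for row in image[r0:r1]]


# Toy slice after intensity scaling: 0.0 marks voxels at or below -1024 HU
slice_2d = [
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.3, 0.4, 0.0, 0.0],
    [0.0, 0.0, 0.7, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0],
]
print(crop_foreground_2d(slice_2d))  # [[0.3, 0.4], [0.0, 0.7]]
```

In 3D the same bounding box is taken along all three spatial axes, which shrinks the volume the encoder has to process.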

## Run Inference
Process an input CT scan and extract features


```python
# Input path (replace with the path to your own CT scan, e.g. a NIfTI file)
input_path = "/home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    output = model(input_tensor.unsqueeze(0))[-1]

    # Global average pooling collapses the final feature map into a single vector by averaging over
    # all spatial positions. To keep the low-resolution spatial feature maps instead, remove this
    # line and use `output` directly.
    avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")
```

    ✅ Feature extraction completed
    Output shape: torch.Size([512])

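`adaptive_avg_pool3d(output, 1)` reduces each channel of the final feature map, of shape `(1, C, D, H, W)`, to its mean over all `D × H × W` positions, and `.squeeze()` then drops the singleton dimensions, leaving the `(C,)` vector printed above (`C = 512`). A dependency-free sketch of that reduction on a toy map, using a hypothetical helper `global_avg_pool` and nested lists in place of a tensor (batch dimension already removed):

```python
def global_avg_pool(feature_map):
    """Average each channel of a [C][D][H][W] nested list over all spatial positions."""
    pooled = []
    for channel in feature_map:
        # Flatten the D x H x W grid of this channel into one list of values
        values = [v for plane in channel for row in plane for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


# Toy feature map with C=2 channels on a 2x2x2 spatial grid
toy = [
    [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],  # channel 0
    [[[0.0, 0.0], [0.0, 0.0]], [[2.0, 2.0], [2.0, 2.0]]],  # channel 1
]
print(global_avg_pool(toy))  # [4.5, 1.0]
```

In PyTorch the same reduction can be written as `output.mean(dim=(2, 3, 4))`, which gives shape `(1, C)`. Skipping the pooling keeps information about where in the volume each feature responded, at the cost of a larger output.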


```python
# Plot distribution of features
import matplotlib.pyplot as plt
_ = plt.hist(avg_output.cpu().numpy(), bins=100)
```


    
![png](ct_fm_feature_extractor_files/ct_fm_feature_extractor_10_0.png)
    


