CT-FM Feature Extractor

This model is the feature extractor for CT-FM, a model pre-trained with contrastive self-supervised learning on 148,000 CT scans from the Imaging Data Commons.

The backbone is based on SegResNet, a 3D U-Net variant. If you only want to load the model and fine-tune it, you can skip the feature extraction workflow.

Running instructions

This notebook demonstrates how to:

  1. Load an SSL pre-trained model
  2. Set up preprocessing and postprocessing pipelines
  3. Perform inference on CT volumes
  4. Plot the distribution of the extracted features

Setup

Install requirements and import necessary packages

# Install lighter_zoo package
%pip install lighter_zoo -U -qq
Note: you may need to restart the kernel to use updated packages.
# Imports
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)
from monai.inferers import SlidingWindowInferer

Load Model

Download and initialize the pre-trained model from HuggingFace Hub

# Load pre-trained model
model = SegResEncoder.from_pretrained(
    "project-lighter/ct_fm_feature_extractor"
)

Setup Processing Pipelines

Define preprocessing transforms

# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                         # Ensure correct data type
    Orientation(axcodes="SPL"),           # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True       # Clip values outside range
    ),
    CropForeground()    # Remove background to reduce computation
])
monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
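
The intensity scaling step in the pipeline above is a linear map from the HU window [-1024, 2048] to [0, 1], followed by clipping. The sketch below reproduces that arithmetic in plain NumPy so the effect on individual HU values is easy to check. The helper name `scale_intensity_range` and the sample values are illustrative assumptions, not part of MONAI or lighter_zoo.

```python
import numpy as np


def scale_intensity_range(volume, a_min=-1024.0, a_max=2048.0,
                          b_min=0.0, b_max=1.0, clip=True):
    """Hypothetical helper: linearly map [a_min, a_max] to [b_min, b_max]."""
    volume = np.asarray(volume, dtype=np.float32)
    # Normalize to [0, 1] relative to the source window
    scaled = (volume - a_min) / (a_max - a_min)
    # Stretch and shift to the target range
    scaled = scaled * (b_max - b_min) + b_min
    if clip:
        # Values outside the source window saturate at the target bounds
        scaled = np.clip(scaled, b_min, b_max)
    return scaled


# Illustrative HU values: below the window, at its edges, in the middle, above it
hu_values = np.array([-2000.0, -1024.0, 512.0, 2048.0, 3000.0])
scaled_values = scale_intensity_range(hu_values)
print(scaled_values)
```

Anything below -1024 HU maps to 0 and anything above 2048 HU maps to 1, so the midpoint of the window (512 HU) lands at 0.5.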

Run Inference

Process an input CT scan and extract features

# Input path
input_path = "/home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    output = model(input_tensor.unsqueeze(0))[-1]

    # Average pooling compresses the feature maps over all spatial positions into a single vector.
    # If this is not desired, remove this line and use the output tensor directly, which gives you
    # the low-resolution feature maps.
    avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")
✅ Feature extraction completed
Output shape: torch.Size([512])
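
Adaptive average pooling with an output size of 1 is equivalent to taking the mean over the three spatial axes of the feature map. The NumPy sketch below illustrates that reduction on a random stand-in tensor. The spatial size (4, 6, 6) is an assumption chosen for illustration; only the 512 channels match the output shape reported above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the encoder output: (batch, channels, depth, height, width).
# The spatial dimensions here are hypothetical and depend on the input volume.
feature_map = rng.standard_normal((1, 512, 4, 6, 6)).astype(np.float32)

# Global average pooling: one mean per channel over depth, height and width
pooled = feature_map.mean(axis=(2, 3, 4))   # shape (1, 512)

# Dropping the singleton batch axis gives one vector per scan
embedding = pooled.squeeze()                # shape (512,)

print(pooled.shape, embedding.shape)
```

Note that `squeeze()` only yields a flat vector when the batch size is 1; with several scans in a batch the result would keep a leading batch axis.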
# Plot distribution of features
import matplotlib.pyplot as plt
_ = plt.hist(avg_output.cpu().numpy(), bins=100)
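
When no display is available, the same distribution can be inspected numerically. The sketch below uses `np.histogram` with the same bin count as the plot. The `features` array is a random placeholder standing in for `avg_output.cpu().numpy()`, so the printed numbers are not real model outputs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder for avg_output.cpu().numpy(); real values come from the model
features = rng.standard_normal(512).astype(np.float32)

# Same binning as the plot: 100 equal-width bins over the observed range
counts, bin_edges = np.histogram(features, bins=100)
fullest = int(counts.argmax())

print(f"mean={features.mean():.3f} std={features.std():.3f}")
print(f"range=[{features.min():.3f}, {features.max():.3f}]")
print(f"fullest bin: [{bin_edges[fullest]:.3f}, {bin_edges[fullest + 1]:.3f}) "
      f"with {counts[fullest]} values")
```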

Model size: 77.8M params (Safetensors, tensor type F32)