CT-FM Feature Extractor
This is the feature extractor from CT-FM, a model pre-trained with contrastive self-supervised learning on 148,000 CT scans from the Imaging Data Commons.
The backbone is based on SegResNet, a 3D U-Net variant. If you just want to load the model and fine-tune it, you can ignore the feature extraction workflow below.
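If you want to fine-tune instead, here is a minimal sketch of wrapping the encoder with a task head. The linear head, class count, and pooling are illustrative assumptions, not part of the released model; the 512-dimensional feature size matches the pooled output shown later in this card.

import torch
from lighter_zoo import SegResEncoder

# Load the pre-trained encoder as a backbone
encoder = SegResEncoder.from_pretrained("project-lighter/ct_fm_feature_extractor")

class CTClassifier(torch.nn.Module):
    """Illustrative wrapper: pooled CT-FM features -> linear classification head."""
    def __init__(self, encoder, num_classes=2, feat_dim=512):
        super().__init__()
        self.encoder = encoder
        self.head = torch.nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        feats = self.encoder(x)[-1]  # deepest feature maps from the encoder
        pooled = torch.nn.functional.adaptive_avg_pool3d(feats, 1).flatten(1)
        return self.head(pooled)

model = CTClassifier(encoder)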
Running instructions
This notebook demonstrates how to:
- Load an SSL pre-trained model
- Set up the preprocessing pipeline
- Perform inference on CT volumes
- Plot the distribution of the extracted features
Setup
Install requirements and import necessary packages
# Install lighter_zoo package
%pip install lighter_zoo -U -qq
# Imports
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
Compose, LoadImage, EnsureType, Orientation,
ScaleIntensityRange, CropForeground
)
from monai.inferers import SlidingWindowInferer
Load Model
Download and initialize the pre-trained model from the Hugging Face Hub
# Load pre-trained model
model = SegResEncoder.from_pretrained(
"project-lighter/ct_fm_feature_extractor"
)
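Optionally, put the model in evaluation mode and move it to a GPU before inference; a minimal sketch (not part of the original notebook):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()  # eval() disables dropout/batch-norm updates

If you do this, move the preprocessed input tensor to the same device (e.g. input_tensor.to(device)) before calling the model.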
Setup Preprocessing Pipeline
Define preprocessing transforms
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                          # Ensure correct data type
    Orientation(axcodes="SPL"),            # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,  # Min HU value
        a_max=2048,   # Max HU value
        b_min=0,      # Target min
        b_max=1,      # Target max
        clip=True     # Clip values outside range
    ),
    CropForeground()  # Remove background to reduce computation
])
Note: MONAI warns that the current default allow_smaller=True in CropForeground has been deprecated since version 1.2 and will change to allow_smaller=False in version 1.5.
Run Inference
Process an input CT scan and extract features
# Input path (replace with the path to your own CT scan)
input_path = "/home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    output = model(input_tensor.unsqueeze(0))[-1]

# Average pooling compresses the feature maps into a single feature vector. If this is not desired,
# remove this line and use the output tensor directly, which gives you the feature maps in a low-dimensional space.
avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")
✅ Feature extraction completed
Output shape: torch.Size([512])
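If you keep the unpooled output instead, you get low-resolution feature maps rather than a single vector; a minimal sketch of inspecting and reshaping them (the exact spatial size depends on the input volume):

# Unpooled feature maps: (batch, channels, D, H, W)
print(output.shape)

# One feature vector per spatial location, e.g. for patch-level analysis
patch_features = output.flatten(2).transpose(1, 2)  # (batch, D*H*W, channels)
print(patch_features.shape)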
# Plot distribution of features
import matplotlib.pyplot as plt
_ = plt.hist(avg_output.cpu().numpy(), bins=100)
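As a usage example, the pooled features can be compared across scans, e.g. for retrieval; a minimal sketch assuming a second, hypothetical scan at other_scan.nii.gz:

# Extract features for a second scan (hypothetical path) and compare
other_tensor = preprocess("other_scan.nii.gz")
with torch.no_grad():
    other_output = model(other_tensor.unsqueeze(0))[-1]
other_features = torch.nn.functional.adaptive_avg_pool3d(other_output, 1).squeeze()

similarity = torch.nn.functional.cosine_similarity(
    avg_output.unsqueeze(0), other_features.unsqueeze(0)
).item()
print(f"Cosine similarity between scans: {similarity:.3f}")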