metadata
library_name: project-lighter
tags:
  - lighter
  - model_hub_mixin
  - pytorch_model_hub_mixin
language: en
license: apache-2.0
arxiv: 2501.09001

CT-FM Feature Extractor

This model is a feature extractor for CT-FM, a model pre-trained with contrastive self-supervised learning on a large dataset of 148,000 CT scans from the Imaging Data Commons.

The backbone is based on SegResNet, a 3D U-Net variant. If you only want to load the model and fine-tune it, you can skip the feature extraction workflow below.

Running instructions


This notebook demonstrates how to:

  1. Load an SSL pre-trained model
  2. Set up preprocessing and postprocessing pipelines
  3. Perform inference on CT volumes
  4. Plot the distribution of the extracted features

Setup

Install requirements and import necessary packages

# Install lighter_zoo package
%pip install lighter_zoo -U -qq
Note: you may need to restart the kernel to use updated packages.
# Imports
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)
from monai.inferers import SlidingWindowInferer

Load Model

Download and initialize the pre-trained model from the Hugging Face Hub

# Load pre-trained model
model = SegResEncoder.from_pretrained(
    "project-lighter/ct_fm_feature_extractor"
)

Setup Processing Pipelines

Define preprocessing transforms

# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                         # Ensure correct data type
    Orientation(axcodes="SPL"),           # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True       # Clip values outside range
    ),
    CropForeground()    # Remove background to reduce computation
])
monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
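To make the intensity scaling step in the pipeline above concrete, here is a minimal plain-Python sketch of the linear mapping that `ScaleIntensityRange` applies with the parameters shown (`a_min=-1024`, `a_max=2048`, `b_min=0`, `b_max=1`, `clip=True`). The function name `scale_intensity_range` and the sample HU values are illustrative assumptions, not part of MONAI or of this model's code.

```python
def scale_intensity_range(values, a_min=-1024.0, a_max=2048.0,
                          b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] onto [b_min, b_max], optionally clipping."""
    scaled = []
    for v in values:
        # Normalize to [0, 1] relative to the source range, then rescale
        s = (v - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
        if clip:
            # Values outside the source range are clamped to the target range
            s = min(max(s, b_min), b_max)
        scaled.append(s)
    return scaled


# Hypothetical HU samples: below range, lower bound, water (0 HU),
# midpoint of the range, and above range
hu_samples = [-2000.0, -1024.0, 0.0, 512.0, 3000.0]
scaled = scale_intensity_range(hu_samples)
print(scaled)
```

With clipping enabled, anything below -1024 HU maps to 0 and anything above 2048 HU maps to 1, so extreme outliers such as metal artifacts do not stretch the input range seen by the model.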

Run Inference

Process an input CT scan and extract features

# Input path (replace with the path to your own CT volume)
input_path = "/home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    output = model(input_tensor.unsqueeze(0))[-1]

    # Global average pooling collapses the feature maps across all spatial locations into a single feature vector.
    # If this is not desired, remove this line and use the output tensor directly, which gives you the feature maps in a low-dimensional space.
    avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")
✅ Feature extraction completed
Output shape: torch.Size([512])
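The pooling step above, `adaptive_avg_pool3d(output, 1)`, averages each channel over all of its spatial positions, which is why a 5D feature map reduces to one value per channel (512 here). The following is a small plain-Python sketch of that reduction on a toy input. The function name `global_average_pool` and the toy values are assumptions for illustration only; the real computation runs on torch tensors.

```python
def global_average_pool(feature_maps):
    """Average each channel's 3D feature map down to a single number.

    feature_maps: list of channels, each a nested list indexed
    [depth][height][width].
    """
    pooled = []
    for channel in feature_maps:
        # Flatten the 3D volume of this channel into one list of values
        voxels = [v for plane in channel for row in plane for v in row]
        pooled.append(sum(voxels) / len(voxels))
    return pooled


# Toy input: 2 channels, each of spatial size 2 x 2 x 2
toy = [
    [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
    [[[0.0, 0.0], [0.0, 0.0]], [[4.0, 4.0], [4.0, 4.0]]],
]
feature_vector = global_average_pool(toy)
print(feature_vector)  # one value per channel
```

Skipping this reduction keeps the spatial layout, which may be preferable when location-specific features matter more than a single summary vector per scan.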
# Plot distribution of features
import matplotlib.pyplot as plt
_ = plt.hist(avg_output.cpu().numpy(), bins=100)

[Figure: histogram of the 512 extracted feature values, 100 bins]
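If no plotting backend is available, the same distribution can be summarized as bin counts. The sketch below bins values into equal-width intervals between the minimum and maximum, similar in spirit to what a histogram plot displays. The function name `histogram` and the sample values are hypothetical, and exact bin boundaries may differ slightly from matplotlib's for values that fall on an edge.

```python
def histogram(values, bins=10):
    """Count values in equal-width bins spanning [min(values), max(values)]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Degenerate case: widen the range so the bin width is non-zero
        lo, hi = lo - 0.5, hi + 0.5
    width = (hi - lo) / bins
    counts = [0] * bins
    for v in values:
        # The maximum value is placed in the last bin rather than overflowing
        idx = min(int((v - lo) / width), bins - 1)
        counts[idx] += 1
    edges = [lo + i * width for i in range(bins + 1)]
    return counts, edges


# Hypothetical feature values standing in for avg_output.cpu().tolist()
sample_features = [0.0, 0.1, 0.2, 0.9, 1.0]
counts, edges = histogram(sample_features, bins=4)
print(counts)
```

For the real feature vector, passing `avg_output.cpu().tolist()` with `bins=100` would give counts over the same number of bins as the plot.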