---
library_name: project-lighter
tags:
- lighter
- model_hub_mixin
- pytorch_model_hub_mixin
language: en
license: apache-2.0
arxiv: 2501.09001
---

# CT-FM Feature Extractor

This model is a feature extractor for CT-FM, a model pre-trained with contrastive self-supervised learning on 148,000 CT scans from the Imaging Data Commons. The backbone is based on SegResNet, a 3D U-Net variant.

If you only want to load the model for fine-tuning, you can skip the feature extraction workflow below.

## Running instructions

This notebook demonstrates how to:

1. Load an SSL-pretrained model
2. Set up preprocessing and postprocessing pipelines
3. Perform inference on CT volumes
4. Plot the distribution of the extracted features

## Setup

Install the requirements and import the necessary packages.

```python
# Install the lighter_zoo package
%pip install lighter_zoo -U -qq
```

You may need to restart the kernel to use the updated packages.

```python
# Imports
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)
```

## Load Model

Download and initialize the pre-trained model from the Hugging Face Hub.

```python
# Load the pre-trained model
model = SegResEncoder.from_pretrained(
    "project-lighter/ct_fm_feature_extractor"
)
model.eval()  # Put the model in evaluation mode for inference
```

## Setup Processing Pipelines

Define the preprocessing transforms.

```python
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure a channel dimension
    EnsureType(),                          # Ensure correct data type
    Orientation(axcodes="SPL"),            # Standardize orientation
    # Scale intensity to the [0, 1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,  # Min HU value
        a_max=2048,   # Max HU value
        b_min=0,      # Target min
        b_max=1,      # Target max
        clip=True     # Clip values outside the range
    ),
    CropForeground()  # Remove background to reduce computation
])
```
With the MONAI version used to run this notebook, constructing `CropForeground()` without arguments emits a deprecation warning:

```
monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
```

Passing the argument explicitly, for example `CropForeground(allow_smaller=True)` to keep the current behavior, avoids relying on a default that changes in MONAI 1.5.

## Run Inference

Process an input CT scan and extract features.

```python
# Input path: replace with the path to your own CT volume (e.g. a NIfTI file)
input_path = "path/to/your_ct_scan.nii.gz"

# Preprocess the input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    # The encoder returns a list of feature maps; [-1] takes the last (deepest) one
    output = model(input_tensor.unsqueeze(0))[-1]

# Global average pooling collapses the spatial dimensions of the feature map into a
# single feature vector for the whole scan. To keep the downsampled spatial feature
# maps instead, skip this step and use `output` directly.
avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")
```

```
✅ Feature extraction completed
Output shape: torch.Size([512])
```

```python
# Plot the distribution of the extracted features
import matplotlib.pyplot as plt

_ = plt.hist(avg_output.cpu().numpy(), bins=100)
```

![Histogram of the 512 extracted feature values](ct_fm_feature_extractor_files/ct_fm_feature_extractor_10_0.png)
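The pooling step in the inference cell can be checked in isolation. The following is a minimal, model-free sketch: it assumes the deepest feature map has shape `(1, 512, D, H, W)`, consistent with the 512-dimensional output reported above, and substitutes a random tensor with a hypothetical spatial size for real encoder output. It shows that `adaptive_avg_pool3d(output, 1)` is simply a mean over the spatial dimensions, and how you might keep one feature vector per spatial location instead of pooling.

```python
import torch
import torch.nn.functional as F

# Stand-in for the deepest encoder feature map, shaped (batch, channels, D, H, W).
# 512 channels matches the reported output shape; the spatial size (4, 6, 6) is
# hypothetical and in practice depends on the size of the cropped input volume.
torch.manual_seed(0)
feature_map = torch.randn(1, 512, 4, 6, 6)

# Option 1: global average pooling -> one 512-dim vector for the whole scan.
pooled = F.adaptive_avg_pool3d(feature_map, 1).squeeze()

# With an output size of 1, adaptive average pooling is a plain mean over D, H, W.
spatial_mean = feature_map.mean(dim=(2, 3, 4)).squeeze()
assert torch.allclose(pooled, spatial_mean, atol=1e-6)

# Option 2: keep spatial information -> one 512-dim vector per location.
# Drop the batch axis, flatten D, H, W into a single axis, then put channels last.
per_location = feature_map.squeeze(0).flatten(start_dim=1).T

print(pooled.shape)        # torch.Size([512])
print(per_location.shape)  # torch.Size([144, 512])
```

The pooled vector has a fixed length regardless of scan size, which makes it convenient for comparing scans. The per-location vectors preserve where in the downsampled volume each feature came from, but their number varies with the size of the cropped input.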