CxTissueSeg
Overview
The CxTissueSeg model performs binary segmentation of patches of tissue present in H&E pathology slides. It is designed to run efficiently on resource-constrained systems, segmenting the tissue on a slide in under 1 second on a typical CPU.
The model is trained on a manually curated set of slides from our linked dataset, where it achieves 0.93 mIoU for tissue on the test split. By default, the model outputs logits, where the positive class is predicted tissue and the negative class is predicted background. We recommend using the model with our open-source tiled inference framework, which runs inference on a full image by tiling it and stitching the results.
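Since the model emits raw logits rather than probabilities, a sigmoid is needed before thresholding. The sketch below shows that step with NumPy. The `logits_to_mask` helper and the 0.5 threshold are illustrative assumptions, not part of the released code.

```python
import numpy as np


def logits_to_mask(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Convert raw logits to a boolean tissue mask (hypothetical helper)."""
    # Sigmoid maps logits to probabilities in (0, 1).
    probs = 1.0 / (1.0 + np.exp(-logits))
    # Assumed cutoff: probabilities at or above the threshold count as tissue.
    return probs >= threshold


# Toy 2 x 2 logit map: positive values lean towards tissue.
logits = np.array([[-3.0, 0.2], [4.1, -0.7]])
mask = logits_to_mask(logits)
print(mask)
```

With a 0.5 threshold this is equivalent to checking whether the logit is non-negative.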
This model was trained using PyTorch and Segmentation Models PyTorch.
It uses a UNet decoder with a MobileNet-v3 encoder -- specifically, we use timm/mobilenetv3_small_100 as the encoder.
We provide the model weights in both a pickled format (model.pth) and via safetensors (model.safetensors).
We also provide the model exported to ONNX (model.onnx) for use with ONNX Runtime, so it can be run even more efficiently and across programming languages.
To try a demo of the model running in the browser via ONNX Runtime, see: http://www.conflux.xyz/demos/tissue-segmentation.
We also provide a statically quantized model (int8) usable via ONNX Runtime with model_qint8.onnx, although its accuracy is not on par with the full float32 model (0.85 mIoU rather than 0.93 mIoU).
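For reference, the IoU figures quoted above compare a predicted mask against a ground-truth mask. The following is a minimal sketch of one common way to compute IoU and average it over images. The exact averaging behind the reported mIoU is not specified here, so `iou` and `mean_iou` are hypothetical helpers.

```python
import numpy as np


def iou(pred, target) -> float:
    """Intersection over union of two binary masks (hypothetical helper)."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        # Both masks are empty; treating this as a perfect match is an assumption.
        return 1.0
    return float(np.logical_and(pred, target).sum() / union)


def mean_iou(preds, targets) -> float:
    """Average per-image IoU (one of several possible mIoU definitions)."""
    return float(np.mean([iou(p, t) for p, t in zip(preds, targets)]))


# Toy data: the first prediction overshoots by one pixel, the second is exact.
preds = [np.array([[1, 1], [0, 0]]), np.array([[0, 1], [0, 1]])]
targets = [np.array([[1, 0], [0, 0]]), np.array([[0, 1], [0, 1]])]
score = mean_iou(preds, targets)
print(f"mIoU: {score:.2f}")
```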
For more details on the background of the model, check out the blog post here: http://www.conflux.xyz/blog/tissue-segmentation.
Usage
CxTissueSeg was trained on 512 x 512 pixel patches from thumbnail images of whole slides at 40 microns per pixel (MPP) -- a 4x downsample from the images in the dataset.
Thus, when running inference, use 40 MPP thumbnails and tiles of the same dimensions as in training (512 x 512).
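The downsample factor needed to reach 40 MPP follows directly from a slide's native resolution. The sketch below works through that arithmetic. The helper names and the 0.25 MPP example slide are assumptions for illustration; the 10 MPP case corresponds to the 4x downsample mentioned above.

```python
TARGET_MPP = 40.0  # resolution the model was trained at, in microns per pixel


def downsample_factor(native_mpp: float, target_mpp: float = TARGET_MPP) -> float:
    """Factor by which to shrink each dimension to reach the target MPP."""
    if native_mpp <= 0:
        raise ValueError("native_mpp must be positive")
    return target_mpp / native_mpp


def thumbnail_size(width: int, height: int, native_mpp: float,
                   target_mpp: float = TARGET_MPP) -> tuple[int, int]:
    """Thumbnail (width, height) in pixels at the target MPP."""
    factor = downsample_factor(native_mpp, target_mpp)
    return max(1, round(width / factor)), max(1, round(height / factor))


# A 10 MPP image needs a 4x downsample to reach 40 MPP.
print(downsample_factor(10.0))
# Hypothetical slide scanned at 0.25 MPP, 100000 x 80000 pixels.
print(thumbnail_size(100_000, 80_000, native_mpp=0.25))
```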
When padding tiles, pad with pure white: rgb(255, 255, 255).
To make this easier, we provide a more general segmentation library to aid in performing tiled inference: https://github.com/conflux-xyz/conflux-segmentation.
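To illustrate what tiled inference involves, here is a minimal sketch that splits an image into 512 x 512 tiles, pads edge tiles with white, and stitches the per-tile predictions back together. It is not the library's implementation: `predict_tile` is a hypothetical stand-in for the model, and the non-overlapping tiling is an assumption (a real framework may overlap and blend tiles).

```python
import numpy as np

TILE = 512  # tile size the model was trained on


def predict_tile(tile: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the model: flags any non-white pixel as tissue."""
    return (tile < 255).any(axis=-1)


def segment_tiled(image: np.ndarray, predict=predict_tile, tile: int = TILE) -> np.ndarray:
    """Run `predict` over non-overlapping tiles and stitch an H x W boolean mask."""
    h, w = image.shape[:2]
    mask = np.zeros((h, w), dtype=bool)
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            patch = image[y:y + tile, x:x + tile]
            ph, pw = patch.shape[:2]
            # Pad edge tiles to the full tile size with pure white.
            padded = np.full((tile, tile, 3), 255, dtype=image.dtype)
            padded[:ph, :pw] = patch
            # Keep only the part of the prediction that covers real pixels.
            mask[y:y + ph, x:x + pw] = predict(padded)[:ph, :pw]
    return mask


# White 700 x 900 canvas with one dark 100 x 50 rectangle standing in for tissue.
image = np.full((700, 900, 3), 255, dtype=np.uint8)
image[100:200, 600:650] = 0
mask = segment_tiled(image)
print(mask.shape, int(mask.sum()))
```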
Create a segmentation model
ONNX
# pip install conflux-segmentation[onnx] onnxruntime
import onnxruntime as ort
from conflux_segmentation import Segmenter
session = ort.InferenceSession("/path/to/model.onnx")
segmenter = Segmenter.from_onnx(session, activation="sigmoid")
PyTorch
# pip install conflux-segmentation[torch] torch segmentation-models-pytorch
import torch
import segmentation_models_pytorch as smp
from conflux_segmentation import Segmenter
net = smp.Unet(encoder_name="tu-mobilenetv3_small_100", encoder_weights=None, activation=None)
net.load_state_dict(torch.load("/path/to/model.pth", weights_only=True))
# alternatively with safetensors:
# import safetensors.torch
# net.load_state_dict(safetensors.torch.load_file("/path/to/model.safetensors"))
net.eval()
# Optionally, trace the model to get a TorchScript ScriptModule
# example = torch.randn(1, 3, 512, 512)
# net = torch.jit.trace(net, example)
# net.eval()
segmenter = Segmenter.from_torch(net, activation="sigmoid")
Segment!
import cv2
# A 40 MPP thumbnail: H x W x 3 image array of np.uint8
image = cv2.cvtColor(cv2.imread("/path/to/large/image"), cv2.COLOR_BGR2RGB)
# Alternatively, use `openslide` or `tiffslide` to get a 40 MPP thumbnail
# H x W boolean array
mask = segmenter(image).to_binary().get_mask()
tissue_fraction = mask.sum() / mask.size
print(f"Fraction of slide with tissue: {tissue_fraction:.3f}")
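Because the mask is at a known resolution, the tissue fraction can also be turned into a physical area: at 40 MPP each pixel covers 40 x 40 = 1600 square microns. The sketch below shows that conversion. `tissue_area_mm2` and the toy mask are hypothetical and stand in for the mask produced above.

```python
import numpy as np

MPP = 40.0  # microns per pixel of the thumbnail the mask was computed on


def tissue_area_mm2(mask: np.ndarray, mpp: float = MPP) -> float:
    """Physical tissue area in square millimetres (hypothetical helper)."""
    pixel_area_um2 = mpp * mpp  # 1600 square microns per pixel at 40 MPP
    # 1 mm^2 = 1e6 square microns.
    return float(np.count_nonzero(mask)) * pixel_area_um2 / 1e6


# Toy mask with a 200 x 250 pixel tissue region (50,000 tissue pixels).
example_mask = np.zeros((500, 625), dtype=bool)
example_mask[100:300, 200:450] = True
area = tissue_area_mm2(example_mask)
print(f"Tissue area: {area:.1f} mm^2")
```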
Acknowledgements
We are grateful to the TCGA Research Network from which the slides used for training were originally sourced.
Per their citation request (https://www.cancer.gov/ccg/research/genome-sequencing/tcga/using-tcga-data/citing):
The results shown here are in whole or part based upon data generated by the TCGA Research Network: https://www.cancer.gov/tcga.
Model tree for conflux-xyz/cx-tissue-seg
Base model
timm/mobilenetv3_small_100.lamb_in1k