"""BioNeMo related operations The intention is to showcase how BioNeMo can be integrated with LynxKite. This should be considered as a reference implementation and not a production ready code. The operations are quite specific for this example notebook: https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/examples/bionemo-geneformer/geneformer-celltype-classification.ipynb """ from lynxkite.core import ops import requests import tarfile import os from collections import Counter from . import core import joblib import numpy as np import torch from pathlib import Path import random from contextlib import contextmanager import cellxgene_census # TODO: This needs numpy < 2 import tempfile from sklearn.ensemble import RandomForestClassifier from sklearn.pipeline import Pipeline from sklearn.model_selection import StratifiedKFold, cross_validate from sklearn.metrics import ( make_scorer, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, ) from sklearn.decomposition import PCA from sklearn.model_selection import cross_val_predict from sklearn.preprocessing import LabelEncoder from bionemo.scdl.io.single_cell_collection import SingleCellCollection import scanpy mem = joblib.Memory(".joblib-cache") op = ops.op_registration(core.ENV) DATA_PATH = Path("/workspace") @contextmanager def random_seed(seed: int): state = random.getstate() random.seed(seed) try: yield finally: # Go back to previous state random.setstate(state) @op("BioNeMo > Download CELLxGENE dataset") @mem.cache() def download_cellxgene_dataset( *, save_path: str, census_version: str = "2023-12-15", organism: str = "Homo sapiens", value_filter='dataset_id=="8e47ed12-c658-4252-b126-381df8d52a3d"', max_workers: int = 1, use_mp: bool = False, ) -> None: """Downloads a CELLxGENE dataset""" with cellxgene_census.open_soma(census_version=census_version) as census: adata = cellxgene_census.get_anndata( census, organism, obs_value_filter=value_filter, ) with random_seed(32): indices = list(range(len(adata))) random.shuffle(indices) micro_batch_size: int = 32 num_steps: int = 256 selection = sorted(indices[: micro_batch_size * num_steps]) # NOTE: there's a current constraint that predict_step needs to be a function of micro-batch-size. # this is something we are working on fixing. A quick hack is to set micro-batch-size=1, but this is # slow. In this notebook we are going to use mbs=32 and subsample the anndata. 
    adata = adata[selection].copy()  # So it's not a view.
    h5ad_outfile = DATA_PATH / Path("hs-celltype-bench.h5ad")
    adata.write_h5ad(h5ad_outfile)
    with tempfile.TemporaryDirectory() as temp_dir:
        coll = SingleCellCollection(temp_dir)
        coll.load_h5ad_multi(
            h5ad_outfile.parent, max_workers=max_workers, use_processes=use_mp
        )
        coll.flatten(DATA_PATH / save_path, destroy_on_copy=True)
    return DATA_PATH / save_path


@op("BioNeMo > Import H5AD file")
def import_h5ad(*, file_path: str):
    return scanpy.read_h5ad(DATA_PATH / Path(file_path))


@op("BioNeMo > Download model")
@mem.cache(verbose=1)
def download_model(*, model_name: str) -> str:
    """Downloads a model."""
    model_download_parameters = {
        "geneformer_100m": {
            "name": "geneformer_100m",
            "version": "2.0",
            "path": "geneformer_106M_240530_nemo2",
        },
        "geneformer_10m": {
            "name": "geneformer_10m",
            "version": "2.0",
            "path": "geneformer_10M_240530_nemo2",
        },
        "geneformer_10m2": {
            "name": "geneformer_10m",
            "version": "2.1",
            "path": "geneformer_10M_241113_nemo2",
        },
    }
    # Define the URL and output file.
    url_template = "https://api.ngc.nvidia.com/v2/models/org/nvidia/team/clara/{name}/{version}/files?redirect=true&path={path}.tar.gz"
    url = url_template.format(**model_download_parameters[model_name])
    model_filename = f"{DATA_PATH}/{model_download_parameters[model_name]['path']}"
    output_file = f"{model_filename}.tar.gz"
    # Send the request.
    response = requests.get(url, allow_redirects=True, stream=True)
    response.raise_for_status()  # Raise an error for bad responses (4xx and 5xx).
    # Save the file to disk.
    with open(output_file, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    # Extract the tar.gz file.
    os.makedirs(model_filename, exist_ok=True)
    with tarfile.open(output_file, "r:gz") as tar:
        tar.extractall(path=model_filename)
    return model_filename


@op("BioNeMo > Infer")
@mem.cache(verbose=1)
def infer(
    dataset_path: str, model_path: str | None = None, *, results_path: str
) -> Path:
    """Runs inference on a dataset."""
    # This import is slow, so we only import it when we need it.
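    # (bionemo.geneformer transitively imports NeMo and Megatron, which can take a
    # long time; a module-level import would slow every LynxKite startup even when
    # inference is never used.)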
    from bionemo.geneformer.scripts.infer_geneformer import infer_model

    infer_model(
        data_path=dataset_path,
        checkpoint_path=model_path,
        results_path=DATA_PATH / results_path,
        include_hiddens=False,
        micro_batch_size=32,
        include_embeddings=True,
        include_logits=False,
        seq_length=2048,
        precision="bf16-mixed",
        devices=1,
        num_nodes=1,
        num_dataset_workers=10,
    )
    return DATA_PATH / results_path


@op("BioNeMo > Load results")
def load_results(results_path: str):
    embeddings = (
        torch.load(f"{results_path}/predictions__rank_0.pt")["embeddings"]
        .float()
        .cpu()
        .numpy()
    )
    return embeddings


@op("BioNeMo > Get labels")
def get_labels(adata):
    infer_metadata = adata.obs
    labels = infer_metadata["cell_type"].values
    label_encoder = LabelEncoder()
    integer_labels = label_encoder.fit_transform(labels)
    # Stash the encoded labels on the encoder so downstream ops get both the
    # encoding and the class names from a single value.
    label_encoder.integer_labels = integer_labels
    return label_encoder


@op("BioNeMo > Plot labels", view="visualization")
def plot_labels(adata):
    infer_metadata = adata.obs
    labels = infer_metadata["cell_type"].values
    label_counts = Counter(labels)
    labels = list(label_counts.keys())
    values = list(label_counts.values())
    options = {
        "title": {
            "text": "Cell type counts for classification dataset",
            "left": "center",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {"rotate": 45, "align": "right"},
        },
        "yAxis": {"type": "value"},
        "series": [
            {
                "name": "Count",
                "type": "bar",
                "data": values,
                "itemStyle": {"color": "#4285F4"},
            }
        ],
    }
    return options


@op("BioNeMo > Run benchmark")
@mem.cache(verbose=1)
def run_benchmark(data, labels, *, use_pca: bool = False):
    """Cross-validates a random forest classifier on the given features.

    data: the single cell expression (or any other feature) in each row, shape (R, C).
    labels: a fitted LabelEncoder with an `integer_labels` attribute of shape (R,),
        as produced by `get_labels`.
    """
    np.random.seed(1337)
    # Define the target dimension 'n_components'.
    n_components = 10  # For example; adjust based on your specific needs.
    # Create a pipeline that includes PCA and a RandomForestClassifier.
    if use_pca:
        pipeline = Pipeline(
            [
                ("projection", PCA(n_components=n_components)),
                ("classifier", RandomForestClassifier(class_weight="balanced")),
            ]
        )
    else:
        pipeline = Pipeline(
            [("classifier", RandomForestClassifier(class_weight="balanced"))]
        )
    # Set up StratifiedKFold to ensure each fold reflects the overall distribution of labels.
    cv = StratifiedKFold(n_splits=5)
    # Define the scoring functions.
    scoring = {
        "accuracy": make_scorer(accuracy_score),
        "precision": make_scorer(
            precision_score, average="macro"
        ),  # 'macro' averages over classes.
        "recall": make_scorer(recall_score, average="macro"),
        "f1_score": make_scorer(f1_score, average="macro"),
        # roc_auc_score needs class probabilities (not hard predictions) in the
        # multiclass case, hence needs_proba=True.
        "roc_auc": make_scorer(roc_auc_score, multi_class="ovr", needs_proba=True),
    }
    labels = labels.integer_labels
    # Perform stratified cross-validation with multiple metrics using the pipeline.
    results = cross_validate(
        pipeline, data, labels, cv=cv, scoring=scoring, return_train_score=False
    )
    # Print the cross-validation results.
    print("Cross-validation metrics:")
    results_out = {}
    for metric, scores in results.items():
        if metric.startswith("test_"):
            results_out[metric] = (scores.mean(), scores.std())
            print(f"{metric[5:]}: {scores.mean():.3f} (+/- {scores.std():.3f})")
    # Return the confusion matrix along with the metrics.
    predictions = cross_val_predict(pipeline, data, labels, cv=cv)
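    # cross_val_predict refits the pipeline on each training fold, so every row is
    # predicted by a model that never saw it; the confusion matrix below is
    # therefore out-of-sample.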
    conf_matrix = confusion_matrix(labels, predictions)
    return results_out, conf_matrix


@op("BioNeMo > Plot confusion matrix", view="visualization")
@mem.cache(verbose=1)
def plot_confusion_matrix(benchmark_output, labels):
    cm = benchmark_output[1]
    labels = labels.classes_
    str_labels = [str(label) for label in labels]
    # Normalize each row so the matrix shows per-class prediction rates.
    norm_cm = [[float(val / sum(row)) if sum(row) else 0 for val in row] for row in cm]
    # The heatmap has its (0, 0) at the bottom left corner.
    num_rows = len(str_labels)
    heatmap_data = [
        [j, num_rows - i - 1, norm_cm[i][j]]
        for i in range(len(labels))
        for j in range(len(labels))
    ]
    options = {
        "title": {"text": "Confusion Matrix", "left": "center"},
        "tooltip": {"position": "top"},
        "xAxis": {
            "type": "category",
            "data": str_labels,
            "splitArea": {"show": True},
            "axisLabel": {"rotate": 70, "align": "right"},
        },
        "yAxis": {
            "type": "category",
            "data": list(reversed(str_labels)),
            "splitArea": {"show": True},
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "visualMap": {
            "min": 0,
            "max": 1,
            "calculable": True,
            "orient": "vertical",
            "right": 10,
            "top": "center",
            "inRange": {
                "color": ["#E0F7FA", "#81D4FA", "#29B6F6", "#0288D1", "#01579B"]
            },
        },
        "series": [
            {
                "name": "Confusion matrix",
                "type": "heatmap",
                "data": heatmap_data,
                "emphasis": {"itemStyle": {"borderColor": "#333", "borderWidth": 1}},
                "itemStyle": {"borderColor": "#D3D3D3", "borderWidth": 2},
            }
        ],
    }
    return options


@op("BioNeMo > Plot accuracy comparison", view="visualization")
def accuracy_comparison(benchmark_output10m, benchmark_output100m):
    results_10m = benchmark_output10m[0]
    results_106M = benchmark_output100m[0]
    data = {
        "model": ["10M parameters", "106M parameters"],
        "accuracy_mean": [
            results_10m["test_accuracy"][0],
            results_106M["test_accuracy"][0],
        ],
        "accuracy_std": [
            results_10m["test_accuracy"][1],
            results_106M["test_accuracy"][1],
        ],
    }
    labels = data["model"]  # X-axis labels.
    values = data["accuracy_mean"]  # Y-axis values.
    error_bars = data["accuracy_std"]  # Standard deviation for error bars.
    options = {
        "title": {
            "text": "Accuracy Comparison",
            "left": "center",
            "textStyle": {
                "fontSize": 20,  # Bigger font for the title.
                "fontWeight": "bold",
            },
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {
                "rotate": 45,  # Rotate labels for better readability.
                "align": "right",
                "textStyle": {
                    "fontSize": 14,  # Bigger font for X-axis labels.
                    "fontWeight": "bold",
                },
            },
        },
        "yAxis": {
            "type": "value",
            "name": "Accuracy",
            "min": 0,
            "max": 1,
            "interval": 0.1,  # Y-axis ticks every 0.1.
            "axisLabel": {
                "textStyle": {
                    "fontSize": 14,  # Bigger font for Y-axis labels.
                    "fontWeight": "bold",
                }
            },
        },
        "series": [
            {
                "name": "Accuracy",
                "type": "bar",
                "data": values,
                "itemStyle": {
                    "color": "#440154"  # Viridis color palette (dark purple).
                },
            },
            {
                "name": "Error Bars",
                # NOTE: "errorbar" is not a built-in ECharts series type; a custom
                # series may be needed for the error bars to actually render.
                "type": "errorbar",
                "data": [
                    [val - err, val + err] for val, err in zip(values, error_bars)
                ],
                "itemStyle": {"color": "#1f77b4"},
            },
        ],
    }
    return options


@op("BioNeMo > Plot f1 comparison", view="visualization")
def f1_comparison(benchmark_output10m, benchmark_output100m):
    results_10m = benchmark_output10m[0]
    results_106M = benchmark_output100m[0]
    data = {
        "model": ["10M parameters", "106M parameters"],
        "f1_score_mean": [
            results_10m["test_f1_score"][0],
            results_106M["test_f1_score"][0],
        ],
        "f1_score_std": [
results_10m["test_f1_score"][1], results_106M["test_f1_score"][1], ], } labels = data["model"] # X-axis labels values = data["f1_score_mean"] # Y-axis values error_bars = data["f1_score_std"] # Standard deviation for error bars options = { "title": { "text": "F1 Score Comparison", "left": "center", "textStyle": { "fontSize": 20, # Bigger font for title "fontWeight": "bold", # Make title bold }, }, "grid": { "height": "70%", "width": "70%", "left": "20%", "right": "10%", "bottom": "10%", "top": "10%", }, "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}}, "xAxis": { "type": "category", "data": labels, "axisLabel": { "rotate": 45, # Rotate labels for better readability "align": "right", "textStyle": { "fontSize": 14, # Bigger font for X-axis labels "fontWeight": "bold", }, }, }, "yAxis": { "type": "value", "name": "F1 Score", "min": 0, "max": 1, "interval": 0.1, # Matches np.arange(0, 1.05, 0.05), "axisLabel": { "textStyle": { "fontSize": 14, # Bigger font for X-axis labels "fontWeight": "bold", } }, }, "series": [ { "name": "F1 Score", "type": "bar", "data": values, "itemStyle": { "color": "#440154" # Viridis color palette (dark purple) }, }, { "name": "Error Bars", "type": "errorbar", "data": [ [val - err, val + err] for val, err in zip(values, error_bars) ], "itemStyle": {"color": "#1f77b4"}, }, ], } return options