"""BioNeMo related operations | |
The intention is to showcase how BioNeMo can be integrated with LynxKite. This should be | |
considered as a reference implementation and not a production ready code. | |
The operations are quite specific for this example notebook: | |
https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/examples/bionemo-geneformer/geneformer-celltype-classification.ipynb | |
""" | |
from lynxkite.core import ops
import requests
import tarfile
import os
from collections import Counter
from . import core
import joblib
import numpy as np
import torch
from pathlib import Path
import random
from contextlib import contextmanager
import cellxgene_census  # TODO: This needs numpy < 2
import tempfile
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_validate
from sklearn.metrics import (
    make_scorer,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    confusion_matrix,
)
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from bionemo.scdl.io.single_cell_collection import SingleCellCollection
import scanpy

mem = joblib.Memory(".joblib-cache")
# `op` registers the functions below as LynxKite operations in this environment.
op = ops.op_registration(core.ENV)
DATA_PATH = Path("/workspace")

@contextmanager
def random_seed(seed: int):
    """Context manager that seeds the `random` module and restores the previous RNG state on exit."""
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Go back to the previous state.
        random.setstate(state)
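
# A minimal sketch of the intended use: draws inside the block are reproducible,
# and the global RNG continues undisturbed afterwards.
#
#     with random_seed(32):
#         reproducible = random.random()
#     unaffected = random.random()  # continues from the pre-existing RNG state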

def download_cellxgene_dataset(
    *,
    save_path: str,
    census_version: str = "2023-12-15",
    organism: str = "Homo sapiens",
    value_filter='dataset_id=="8e47ed12-c658-4252-b126-381df8d52a3d"',
    max_workers: int = 1,
    use_mp: bool = False,
) -> Path:
    """Downloads a CELLxGENE dataset and converts it to the SCDL format BioNeMo expects."""
    with cellxgene_census.open_soma(census_version=census_version) as census:
        adata = cellxgene_census.get_anndata(
            census,
            organism,
            obs_value_filter=value_filter,
        )
    with random_seed(32):
        indices = list(range(len(adata)))
        random.shuffle(indices)
    micro_batch_size: int = 32
    num_steps: int = 256
    selection = sorted(indices[: micro_batch_size * num_steps])
    # NOTE: there's a current constraint that predict_step needs to be a function of micro-batch-size.
    # This is something we are working on fixing. A quick hack is to set micro-batch-size=1, but this
    # is slow. In this notebook we are going to use mbs=32 and subsample the AnnData.
    adata = adata[selection].copy()  # so it's not a view
    h5ad_outfile = DATA_PATH / Path("hs-celltype-bench.h5ad")
    adata.write_h5ad(h5ad_outfile)
    with tempfile.TemporaryDirectory() as temp_dir:
        coll = SingleCellCollection(temp_dir)
        coll.load_h5ad_multi(
            h5ad_outfile.parent, max_workers=max_workers, use_processes=use_mp
        )
        coll.flatten(DATA_PATH / save_path, destroy_on_copy=True)
    return DATA_PATH / save_path
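
# Usage sketch (hypothetical save_path; assumes a writable /workspace volume,
# network access, and that the "2023-12-15" census release is still served):
#
#     dataset_path = download_cellxgene_dataset(save_path="celltype-bench-dataset")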

def import_h5ad(*, file_path: str):
    return scanpy.read_h5ad(DATA_PATH / Path(file_path))

def download_model(*, model_name: str) -> str:
    """Downloads a model."""
    model_download_parameters = {
        "geneformer_100m": {
            "name": "geneformer_100m",
            "version": "2.0",
            "path": "geneformer_106M_240530_nemo2",
        },
        "geneformer_10m": {
            "name": "geneformer_10m",
            "version": "2.0",
            "path": "geneformer_10M_240530_nemo2",
        },
        "geneformer_10m2": {
            "name": "geneformer_10m",
            "version": "2.1",
            "path": "geneformer_10M_241113_nemo2",
        },
    }
    # Define the URL and output file.
    url_template = "https://api.ngc.nvidia.com/v2/models/org/nvidia/team/clara/{name}/{version}/files?redirect=true&path={path}.tar.gz"
    url = url_template.format(**model_download_parameters[model_name])
    model_filename = f"{DATA_PATH}/{model_download_parameters[model_name]['path']}"
    output_file = f"{model_filename}.tar.gz"
    # Send the request.
    response = requests.get(url, allow_redirects=True, stream=True)
    response.raise_for_status()  # Raise an error for bad responses (4xx and 5xx).
    # Save the file to disk.
    with open(output_file, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    # Extract the tar.gz file.
    os.makedirs(model_filename, exist_ok=True)
    with tarfile.open(output_file, "r:gz") as tar:
        tar.extractall(path=model_filename)
    return model_filename
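
# Usage sketch: the keys of `model_download_parameters` above are the accepted
# model names (requires network access to NGC):
#
#     model_10m_path = download_model(model_name="geneformer_10m")
#     model_100m_path = download_model(model_name="geneformer_100m")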

def infer(
    dataset_path: str, model_path: str | None = None, *, results_path: str
) -> Path:
    """Runs Geneformer inference on a dataset and returns the results directory."""
    # This import is slow, so we only import it when we need it.
    from bionemo.geneformer.scripts.infer_geneformer import infer_model

    infer_model(
        data_path=dataset_path,
        checkpoint_path=model_path,
        results_path=DATA_PATH / results_path,
        include_hiddens=False,
        micro_batch_size=32,
        include_embeddings=True,
        include_logits=False,
        seq_length=2048,
        precision="bf16-mixed",
        devices=1,
        num_nodes=1,
        num_dataset_workers=10,
    )
    return DATA_PATH / results_path
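
# Usage sketch (hypothetical paths; devices=1 above assumes a single GPU):
#
#     results_dir = infer(str(dataset_path), model_10m_path, results_path="results-10m")
#     embeddings = load_results(results_dir)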

def load_results(results_path: str):
    """Loads the embeddings produced by `infer` as a NumPy array."""
    embeddings = (
        torch.load(f"{results_path}/predictions__rank_0.pt")["embeddings"]
        .float()
        .cpu()
        .numpy()
    )
    return embeddings

def get_labels(adata):
    """Encodes the cell type labels as integers, keeping the encoder for later decoding."""
    infer_metadata = adata.obs
    labels = infer_metadata["cell_type"].values
    label_encoder = LabelEncoder()
    integer_labels = label_encoder.fit_transform(labels)
    # Stash the encoded labels on the encoder so downstream ops get both in one object.
    label_encoder.integer_labels = integer_labels
    return label_encoder

def plot_labels(adata):
    """Builds an ECharts bar chart of the cell type counts in the dataset."""
    infer_metadata = adata.obs
    labels = infer_metadata["cell_type"].values
    label_counts = Counter(labels)
    labels = list(label_counts.keys())
    values = list(label_counts.values())
    options = {
        "title": {
            "text": "Cell type counts for classification dataset",
            "left": "center",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {"rotate": 45, "align": "right"},
        },
        "yAxis": {"type": "value"},
        "series": [
            {
                "name": "Count",
                "type": "bar",
                "data": values,
                "itemStyle": {"color": "#4285F4"},
            }
        ],
    }
    return options

def run_benchmark(data, labels, *, use_pca: bool = False):
    """Cross-validates a cell type classifier on the given features.

    data - the single cell expression (or any other feature) for each cell, shape (R, C).
    labels - a fitted LabelEncoder (see `get_labels`) carrying the integer label
             for each cell, shape (R,).
    """
    np.random.seed(1337)
    # Define the target dimension 'n_components'.
    n_components = 10  # for example, adjust based on your specific needs
    # Create a pipeline with an optional PCA projection in front of a RandomForestClassifier.
    if use_pca:
        pipeline = Pipeline(
            [
                ("projection", PCA(n_components=n_components)),
                ("classifier", RandomForestClassifier(class_weight="balanced")),
            ]
        )
    else:
        pipeline = Pipeline(
            [("classifier", RandomForestClassifier(class_weight="balanced"))]
        )
    # Set up StratifiedKFold to ensure each fold reflects the overall distribution of labels.
    cv = StratifiedKFold(n_splits=5)
    # Define the scoring functions.
    scoring = {
        "accuracy": make_scorer(accuracy_score),
        "precision": make_scorer(
            precision_score, average="macro"
        ),  # 'macro' averages over classes
        "recall": make_scorer(recall_score, average="macro"),
        "f1_score": make_scorer(f1_score, average="macro"),
        # roc_auc_score needs class probabilities for the multiclass ("ovr") case.
        # `response_method` requires scikit-learn >= 1.4; older versions use needs_proba=True.
        "roc_auc": make_scorer(
            roc_auc_score, multi_class="ovr", response_method="predict_proba"
        ),
    }
    labels = labels.integer_labels
    # Perform stratified cross-validation with multiple metrics using the pipeline.
    results = cross_validate(
        pipeline, data, labels, cv=cv, scoring=scoring, return_train_score=False
    )
    # Print the cross-validation results.
    print("Cross-validation metrics:")
    results_out = {}
    for metric, scores in results.items():
        if metric.startswith("test_"):
            results_out[metric] = (scores.mean(), scores.std())
            print(f"{metric[5:]}: {scores.mean():.3f} (+/- {scores.std():.3f})")
    predictions = cross_val_predict(pipeline, data, labels, cv=cv)
    # Return the per-metric (mean, std) pairs and the confusion matrix.
    conf_matrix = confusion_matrix(labels, predictions)
    return results_out, conf_matrix
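
# Usage sketch, mirroring the reference notebook: compare a PCA baseline on raw
# expression with the Geneformer embeddings (variable names are hypothetical;
# densify adata.X first if it is sparse):
#
#     encoder = get_labels(adata)
#     baseline = run_benchmark(adata.X.toarray(), encoder, use_pca=True)
#     geneformer = run_benchmark(load_results(results_dir), encoder, use_pca=False)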

def plot_confusion_matrix(benchmark_output, labels):
    """Builds an ECharts heatmap of the row-normalized confusion matrix."""
    cm = benchmark_output[1]
    labels = labels.classes_
    str_labels = [str(label) for label in labels]
    norm_cm = [[float(val / sum(row)) if sum(row) else 0 for val in row] for row in cm]
    # The heatmap has (0, 0) in the bottom left corner, so flip the rows.
    num_rows = len(str_labels)
    heatmap_data = [
        [j, num_rows - i - 1, norm_cm[i][j]]
        for i in range(len(labels))
        for j in range(len(labels))
    ]
    options = {
        "title": {"text": "Confusion Matrix", "left": "center"},
        "tooltip": {"position": "top"},
        "xAxis": {
            "type": "category",
            "data": str_labels,
            "splitArea": {"show": True},
            "axisLabel": {"rotate": 70, "align": "right"},
        },
        "yAxis": {
            "type": "category",
            "data": list(reversed(str_labels)),
            "splitArea": {"show": True},
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "visualMap": {
            "min": 0,
            "max": 1,
            "calculable": True,
            "orient": "vertical",
            "right": 10,
            "top": "center",
            "inRange": {
                "color": ["#E0F7FA", "#81D4FA", "#29B6F6", "#0288D1", "#01579B"]
            },
        },
        "series": [
            {
                "name": "Confusion matrix",
                "type": "heatmap",
                "data": heatmap_data,
                "emphasis": {"itemStyle": {"borderColor": "#333", "borderWidth": 1}},
                "itemStyle": {"borderColor": "#D3D3D3", "borderWidth": 2},
            }
        ],
    }
    return options

def accuracy_comparison(benchmark_output10m, benchmark_output100m):
    """Builds an ECharts bar chart comparing the accuracy of the two models."""
    results_10m = benchmark_output10m[0]
    results_106M = benchmark_output100m[0]
    data = {
        "model": ["10M parameters", "106M parameters"],
        "accuracy_mean": [
            results_10m["test_accuracy"][0],
            results_106M["test_accuracy"][0],
        ],
        "accuracy_std": [
            results_10m["test_accuracy"][1],
            results_106M["test_accuracy"][1],
        ],
    }
    labels = data["model"]  # X-axis labels
    values = data["accuracy_mean"]  # Y-axis values
    error_bars = data["accuracy_std"]  # Standard deviation for error bars
    options = {
        "title": {
            "text": "Accuracy Comparison",
            "left": "center",
            "textStyle": {
                "fontSize": 20,  # Bigger font for the title.
                "fontWeight": "bold",
            },
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {
                "rotate": 45,  # Rotate labels for better readability.
                "align": "right",
                "textStyle": {
                    "fontSize": 14,  # Bigger font for X-axis labels.
                    "fontWeight": "bold",
                },
            },
        },
        "yAxis": {
            "type": "value",
            "name": "Accuracy",
            "min": 0,
            "max": 1,
            "interval": 0.1,
            "axisLabel": {
                "textStyle": {
                    "fontSize": 14,  # Bigger font for Y-axis labels.
                    "fontWeight": "bold",
                }
            },
        },
        "series": [
            {
                "name": "Accuracy",
                "type": "bar",
                "data": values,
                "itemStyle": {
                    "color": "#440154"  # Viridis color palette (dark purple).
                },
            },
            # NOTE: ECharts has no built-in "errorbar" series type; rendering these
            # whiskers properly would need a "custom" series.
            {
                "name": "Error Bars",
                "type": "errorbar",
                "data": [
                    [val - err, val + err] for val, err in zip(values, error_bars)
                ],
                "itemStyle": {"color": "#1f77b4"},
            },
        ],
    }
    return options

def f1_comparison(benchmark_output10m, benchmark_output100m):
    """Builds an ECharts bar chart comparing the macro F1 score of the two models."""
    results_10m = benchmark_output10m[0]
    results_106M = benchmark_output100m[0]
    data = {
        "model": ["10M parameters", "106M parameters"],
        "f1_score_mean": [
            results_10m["test_f1_score"][0],
            results_106M["test_f1_score"][0],
        ],
        "f1_score_std": [
            results_10m["test_f1_score"][1],
            results_106M["test_f1_score"][1],
        ],
    }
    labels = data["model"]  # X-axis labels
    values = data["f1_score_mean"]  # Y-axis values
    error_bars = data["f1_score_std"]  # Standard deviation for error bars
    options = {
        "title": {
            "text": "F1 Score Comparison",
            "left": "center",
            "textStyle": {
                "fontSize": 20,  # Bigger font for the title.
                "fontWeight": "bold",
            },
        },
        "grid": {
            "height": "70%",
            "width": "70%",
            "left": "20%",
            "right": "10%",
            "bottom": "10%",
            "top": "10%",
        },
        "tooltip": {"trigger": "axis", "axisPointer": {"type": "shadow"}},
        "xAxis": {
            "type": "category",
            "data": labels,
            "axisLabel": {
                "rotate": 45,  # Rotate labels for better readability.
                "align": "right",
                "textStyle": {
                    "fontSize": 14,  # Bigger font for X-axis labels.
                    "fontWeight": "bold",
                },
            },
        },
        "yAxis": {
            "type": "value",
            "name": "F1 Score",
            "min": 0,
            "max": 1,
            "interval": 0.1,
            "axisLabel": {
                "textStyle": {
                    "fontSize": 14,  # Bigger font for Y-axis labels.
                    "fontWeight": "bold",
                }
            },
        },
        "series": [
            {
                "name": "F1 Score",
                "type": "bar",
                "data": values,
                "itemStyle": {
                    "color": "#440154"  # Viridis color palette (dark purple).
                },
            },
            # NOTE: ECharts has no built-in "errorbar" series type; rendering these
            # whiskers properly would need a "custom" series.
            {
                "name": "Error Bars",
                "type": "errorbar",
                "data": [
                    [val - err, val + err] for val, err in zip(values, error_bars)
                ],
                "itemStyle": {"color": "#1f77b4"},
            },
        ],
    }
    return options
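
# End-to-end sketch of how these operations chain together, in the order the
# reference notebook uses them. All paths and result names are hypothetical, and
# the infer() steps assume a GPU plus the BioNeMo container environment.
if __name__ == "__main__":
    dataset = download_cellxgene_dataset(save_path="celltype-bench-dataset")
    adata = import_h5ad(file_path="hs-celltype-bench.h5ad")
    encoder = get_labels(adata)
    bench = {}
    for name, results_path in [
        ("geneformer_10m", "results-10m"),
        ("geneformer_100m", "results-100m"),
    ]:
        model_path = download_model(model_name=name)
        results_dir = infer(str(dataset), model_path, results_path=results_path)
        bench[name] = run_benchmark(load_results(results_dir), encoder)
    print(accuracy_comparison(bench["geneformer_10m"], bench["geneformer_100m"]))
    print(f1_comparison(bench["geneformer_10m"], bench["geneformer_100m"]))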