scRNA Mouse Erythroid Quick Run¶
This notebook provides a compact, end-to-end workflow showing how to run STEER on single-cell RNA velocity data from the mouse erythroid lineage. The run recorded here uses training_mode="fast" for a quick check; set training_mode="full" to see the standard two-stage STEER pipeline with its default training schedule on a real scRNA-seq dataset.
If GPU memory is limited, velo_batch_size is the main parameter to reduce memory usage during the kinetics-learning stage. The full run can take substantial GPU time, so the "fast" profile is useful for quick checks before a longer run.
All main-figure results associated with this project are available on Zenodo: 10.5281/zenodo.18713189.
Input¶
The recommended input is an AnnData object containing at least spliced and unspliced layers. Optional metadata such as cell types or developmental stages can be kept in adata.obs for downstream interpretation, but spatial coordinates are not required for this single-cell workflow.
import os
import random
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
import seaborn as sns
import torch
import torch.backends.cudnn as cudnn
import steer
from steer.prior.prior import PriorInferenceManager
warnings.filterwarnings("ignore")
%matplotlib inline
print("=== STEER Quick Start Environment ===")
print(f"PyTorch version: {torch.__version__}")
print(f"Scanpy version: {sc.__version__}")
print(f"scVelo version: {scv.__version__}")
print(f"STEER version: {getattr(steer, '__version__', 'dev build')}")
print("====================================")
=== STEER Quick Start Environment ===
PyTorch version: 2.8.0+cu128
Scanpy version: 1.10.2
scVelo version: 0.3.3
STEER version: 2.2.2
====================================
Config¶
- This section includes only the settings that users are most likely to modify. The cell below sets training_mode="fast", which matches the outputs recorded in this notebook; switch to "full" for the complete training schedule.
  - "full": uses the default STEER training arguments without overrides and runs the complete training schedule.
  - "fast": uses a lightweight, quickstart-style parameter set for faster demos and iteration.
  For advanced usage, you can further customize the "fast" setting in get_stage1_training_overrides and get_stage2_training_overrides. For example, you may remove the epochs and patience settings to fall back to the same defaults as "full" mode, while still adjusting velo_batch_size to reduce GPU memory usage. This can provide more thorough training than the default "fast" mode while requiring less GPU memory than "full" mode.
- Only the most important settings are exposed here:
  - data_name / input_dir / result_dir: control input and output locations
  For a first run on a new dataset, it is usually enough to update the data path and training_mode.
class Config:
    # Data
    data_name = "erythroid_lineage"
    input_dir = "/nvme/users/liuzhy/Review_Files/8_MouseGastrulation/erythroid_lineage_data/"
    result_dir = "./results_fast"
    seed = 618
    # Core model and graph settings
    expert = 10  # Number of experts. You can adjust this based on biological prior knowledge
                 # or an estimated cluster number (see quickstart.ipynb).
    smooth_neigh = 100  # Number of neighbors used for smoothing.
                        # Usually 50-100 works well; increase for noisier data.
    # Training profile
    training_mode = "fast"  # Choose from "full" or "fast".
    finetune_epochs = 5000  # Only used when training_mode == "fast".
    # Advanced settings
    corr_mode = "u"  # Use unspliced counts for temporal supervision.
                     # If time inference is unsatisfactory and unspliced counts are low,
                     # try setting this to "s" and rerun.
    neighbor_metric = "cosine"  # Neighbor metric. You can also use Euclidean distance.
    use_us = True  # Use both unspliced and spliced counts as input features.
    use_filter = True  # Filter genes before kinetics learning to keep informative genes.
    filter_gene_number = 1000  # Number of genes kept when filtering is enabled.
    # Prior-inference settings for single-cell data
    fine_method = "none"  # "hierarchical"
    target_size = 300
    direction_base = "expert"  # "fine"

cfg = Config()
INPUT_FILE = os.path.join(cfg.input_dir, f"{cfg.data_name}.h5ad")
RESULT_PATH = os.path.join(cfg.result_dir, f"{cfg.data_name}_quickstart")
os.makedirs(RESULT_PATH, exist_ok=True)
def get_stage1_training_overrides(mode: str) -> dict:
    if mode == "full":
        return {}
    if mode == "fast":
        return {
            "expert_mode": "slim",
            "pretrain_epochs": 500,
            "cluster_epochs": 200,
        }
    raise ValueError(f"Unsupported training_mode: {mode}")

def get_stage2_training_overrides(mode: str, finetune_epochs=None) -> dict:
    if mode == "full":
        return {}
    if mode == "fast":
        if finetune_epochs is None:
            raise ValueError("cfg.finetune_epochs must be set when training_mode='fast'.")
        return {
            "expert_mode": "slim",
            "pretrain_epochs": 500,
            "cluster_epochs": 200,
            "velo_batch_size": 2048,
            "MIN_IMPRO": 0.001,
            "PATIENCE": 500,
            "num_epochs": finetune_epochs,
        }
    raise ValueError(f"Unsupported training_mode: {mode}")

print(f"Input file: {INPUT_FILE}")
print(f"Output directory: {RESULT_PATH}")
print(f"Training mode: {cfg.training_mode}")
Input file: /nvme/users/liuzhy/Review_Files/8_MouseGastrulation/erythroid_lineage_data/erythroid_lineage.h5ad
Output directory: ./results_fast/erythroid_lineage_quickstart
Training mode: fast
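The customization described in the Config notes (dropping the epoch and patience entries so that they fall back to STEER defaults, while still lowering velo_batch_size) can be sketched as a third profile. The "lowmem" idea and the helper below are hypothetical and not part of STEER; only the override keys are taken from this notebook, and the batch size of 1024 is an arbitrary illustrative value.

```python
# Hypothetical "lowmem" profile: not part of STEER, shown only to illustrate
# how the override dictionaries used in this notebook can be customized.
FAST_STAGE2 = {
    "expert_mode": "slim",
    "pretrain_epochs": 500,
    "cluster_epochs": 200,
    "velo_batch_size": 2048,
    "MIN_IMPRO": 0.001,
    "PATIENCE": 500,
    "num_epochs": 5000,
}

# Keys that shorten training or stop it early; removing them defers to defaults.
SCHEDULE_KEYS = {"pretrain_epochs", "cluster_epochs", "MIN_IMPRO", "PATIENCE", "num_epochs"}


def lowmem_stage2_overrides(velo_batch_size: int = 1024) -> dict:
    """Keep the non-schedule overrides and set a smaller velocity batch size."""
    overrides = {k: v for k, v in FAST_STAGE2.items() if k not in SCHEDULE_KEYS}
    overrides["velo_batch_size"] = velo_batch_size
    return overrides


# Overrides are merged on top of the base keyword arguments, as in the notebook.
base_kwargs = {"MODEL_MODE": "whole", "NUM_LOSS_NEIGH": 30}
stage2_kwargs = {**base_kwargs, **lowmem_stage2_overrides()}
print(stage2_kwargs)
```

Because the overrides are plain dictionaries, any key left out simply never reaches the training call, so the library default applies.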
Seed And Device¶
This cell fixes the random seed for reproducibility and automatically selects GPU or CPU.
def setup_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False

setup_seed(cfg.seed)
# Replace "cuda:1" with the index of the GPU you want to use.
if torch.cuda.is_available():
    device = torch.device("cuda:1")
else:
    device = torch.device("cpu")  # Fall back to CPU when no GPU is visible.
print(f"Using device: {device}")
Using device: cuda:1
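The device cell uses a fixed GPU index, which may not exist on another machine. A small helper can make the fallback order explicit. This is only a sketch: pick_device is a hypothetical helper, not part of STEER or PyTorch, and it returns a device string so that the logic can be checked without a GPU.

```python
def pick_device(cuda_available: bool, device_count: int, preferred_index: int = 0) -> str:
    """Return a torch-style device string, falling back when needed.

    Hypothetical helper: prefers cuda:<preferred_index>, then cuda:0, then cpu.
    """
    if not cuda_available or device_count <= 0:
        return "cpu"
    if 0 <= preferred_index < device_count:
        return f"cuda:{preferred_index}"
    # The requested index does not exist on this machine; use the first GPU.
    return "cuda:0"


print(pick_device(True, 2, preferred_index=1))
print(pick_device(True, 1, preferred_index=1))
print(pick_device(False, 0, preferred_index=1))
```

In the notebook, the result could be passed on as torch.device(pick_device(torch.cuda.is_available(), torch.cuda.device_count(), 1)).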
Load Data¶
Load the .h5ad file and check that the required layers are present. When users switch to their own data, this is usually the first place where formatting issues appear.
adata = sc.read_h5ad(INPUT_FILE)
if "X_pca" in adata.obsm:
    del adata.obsm["X_pca"]
if "celltype1" in adata.obs and "celltype" not in adata.obs:
    adata.obs["celltype"] = adata.obs["celltype1"]
print(adata)
print("\nAvailable layers:", list(adata.layers.keys()))
print("Available obsm keys:", list(adata.obsm.keys()))
print(f"Number of cells: {adata.n_obs}")
print(f"Number of genes: {adata.n_vars}")
required_layers = ["spliced", "unspliced"]
missing_layers = [layer for layer in required_layers if layer not in adata.layers]
if missing_layers:
    raise ValueError(f"Missing required layers: {missing_layers}")
if "celltype" not in adata.obs:
    print("Warning: adata.obs['celltype'] not found. Downstream plots will need another annotation key.")
AnnData object with n_obs × n_vars = 9815 × 53801
obs: 'sample', 'stage', 'sequencing.batch', 'theiler', 'celltype'
var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'MURK_gene', 'Δm', 'scaled Δm'
uns: 'celltype_colors'
obsm: 'X_umap'
layers: 'spliced', 'unspliced'
Available layers: ['spliced', 'unspliced']
Available obsm keys: ['X_umap']
Number of cells: 9815
Number of genes: 53801
adata
AnnData object with n_obs × n_vars = 9815 × 53801
obs: 'sample', 'stage', 'sequencing.batch', 'theiler', 'celltype'
var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'MURK_gene', 'Δm', 'scaled Δm'
uns: 'celltype_colors'
obsm: 'X_umap'
layers: 'spliced', 'unspliced'
Preprocess¶
Run the basic normalization step and build the dataframe, adjacency matrix, and processed AnnData object required by STEER.
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
df, adjacency_matrix, adata = steer.preprocess_anndata(
adata,
npc=30,
NUM_AD_NEIGH=30,
SMOOTH_NEIGH=cfg.smooth_neigh,
moments_adj=True,
neighbor_metric=cfg.neighbor_metric,
use_us=cfg.use_us,
)
print(df.head())
print("\nAdjacency matrix shape:", adjacency_matrix.shape)
Filtered out 47456 genes that are detected 20 counts (shared).
Normalized count data: X, spliced, unspliced.
Extracted 2000 highly variable genes.
Logarithmized X.
/nvme/users/liuzhy/miniconda3/envs/steer_dev/lib/python3.10/site-packages/scvelo/preprocessing/utils.py:705: DeprecationWarning: `log1p` is deprecated since scVelo v0.3.0 and will be removed in a future version. Please use `log1p` from `scanpy.pp` instead. log1p(adata)
computing moments based on connectivities
finished (0:00:03) --> added
'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
cellID gene_name unsplice splice orig_unsplice orig_splice \
0 AAAGATCTCTCGAA Arfgef1 0.317008 0.599340 0.0 0.000000
1 AATCTCACTGCTTT Arfgef1 0.498130 0.691799 0.0 1.778597
2 AATGGCTGAAGATG Arfgef1 0.547252 0.733817 0.0 1.034375
3 ACACATCTGTCAAC Arfgef1 0.244253 0.499608 0.0 0.000000
4 ACGACAACTGGAGG Arfgef1 0.319506 0.592695 0.0 0.000000
Mu Ms
0 0.066035 0.201642
1 0.103764 0.232749
2 0.113997 0.246885
3 0.050880 0.168088
4 0.066556 0.199406
Adjacency matrix shape: (9815, 9815)
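To make the roles of the neighbor count and the cosine metric more concrete, the sketch below builds a symmetric k-nearest-neighbor adjacency matrix with NumPy and uses it to average counts over each cell's neighborhood, which is the general idea behind moment layers such as Ms and Mu. It runs on random data under simplifying assumptions and is not the implementation used by steer.preprocess_anndata or scVelo; the function names are made up for this example.

```python
import numpy as np


def cosine_knn_adjacency(X: np.ndarray, k: int) -> np.ndarray:
    """Binary, symmetric kNN adjacency from cosine similarity (self included)."""
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    Xn = X / np.clip(norms, 1e-12, None)
    sim = Xn @ Xn.T
    # Column indices of the k most similar cells for every row.
    top = np.argsort(-sim, axis=1)[:, :k]
    adj = np.zeros_like(sim)
    rows = np.repeat(np.arange(X.shape[0]), k)
    adj[rows, top.ravel()] = 1.0
    np.fill_diagonal(adj, 1.0)      # make sure every cell counts itself
    return np.maximum(adj, adj.T)   # symmetrize


def neighborhood_mean(counts: np.ndarray, adj: np.ndarray) -> np.ndarray:
    """Average each gene over a cell's neighbors (row-normalized adjacency)."""
    weights = adj / adj.sum(axis=1, keepdims=True)
    return weights @ counts


rng = np.random.default_rng(0)
embedding = rng.normal(size=(50, 30))                    # stand-in for 30 PCs
spliced = rng.poisson(2.0, size=(50, 5)).astype(float)   # toy count matrix

adj = cosine_knn_adjacency(embedding, k=10)
smoothed = neighborhood_mean(spliced, adj)
print(adj.shape, smoothed.shape)
```

A larger neighborhood averages over more cells, which is why increasing the smoothing neighbor count tends to help with noisier data at the cost of blurring fine structure.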
Expert Number¶
The expert number used in this workflow was originally chosen from an estimation step rather than fixed arbitrarily. To reproduce the results, the expert number is set directly instead of being re-estimated: Config initializes cfg.expert to 10, and the cell below overrides it to 5 for this run. It can also be re-estimated from the processed embedding before training (see quickstart.ipynb).
cfg.expert = 5
Build PyG Input¶
Package the current dataset into the PyTorch Geometric objects required by STEER training.
dataset = steer.preload_datasets_all_genes_anndata(df=df, MODEL_MODE="pretrain", adata=adata)
pyg_data = steer.create_pyg_data(dataset, adjacency_matrix, normalize=True)
print(pyg_data)
Data(x=[9815, 4000], edge_index=[2, 373875], type_features=[9815, 2000], orig_features=[9815, 4000], cell_ids=[9815], adj=[9815, 9815])
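The printed Data object can be read as follows: edge_index is a 2 × E array of (source, target) node pairs, and x has 4000 columns for 2000 genes, which would be consistent with unspliced and spliced features being concatenated when use_us=True. The NumPy sketch below illustrates both ideas on a toy example. The feature layout is an assumption made for illustration, not a description of what steer.create_pyg_data does internally.

```python
import numpy as np

# Toy dense adjacency for 4 cells (1 = edge); values are illustrative only.
adjacency = np.array(
    [
        [0, 1, 0, 0],
        [1, 0, 1, 0],
        [0, 1, 0, 1],
        [0, 0, 1, 0],
    ]
)

# COO-style edge list: row 0 holds source nodes, row 1 holds target nodes.
src, dst = np.nonzero(adjacency)
edge_index = np.vstack([src, dst])

# Assumed feature layout: unspliced and spliced blocks side by side per cell.
n_cells, n_genes = 4, 3
unspliced = np.ones((n_cells, n_genes))
spliced = np.zeros((n_cells, n_genes))
x = np.concatenate([unspliced, spliced], axis=1)

print(edge_index.shape)
print(x.shape)
```

PyTorch Geometric expects edge_index as an integer tensor of shape [2, num_edges], so an undirected edge appears twice, once in each direction.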
Stage 1 Training¶
Train the first-stage model to obtain the initial representation and cluster structure.
- full mode: do not override training epochs or related controls, so STEER uses its built-in defaults
- fast mode: apply a lightweight quickstart-style configuration for a faster tutorial run
stage1_kwargs = dict(
device=device,
device2=device,
pyg_data=pyg_data,
MODEL_MODE="pretrain",
adata=adata,
NUM_LOSS_NEIGH=30,
corr_mode=cfg.corr_mode,
max_n_cluster=cfg.expert,
path=RESULT_PATH,
)
stage1_kwargs.update(get_stage1_training_overrides(cfg.training_mode))
result_adata = steer.model_training_share_neighbor_adata(**stage1_kwargs)
Epoch 0, Loss 0.12326658517122269
Epoch 50, Loss 0.06401439011096954
Epoch 100, Loss 0.05739820376038551
Epoch 150, Loss 0.051992498338222504
Epoch 200, Loss 0.047516319900751114
Epoch 250, Loss 0.043695274740457535
Epoch 300, Loss 0.04043680429458618
Epoch 350, Loss 0.03751787170767784
Epoch 400, Loss 0.03471369296312332
Epoch 450, Loss 0.032336171716451645
Epoch 500, Loss 1.071047067642212
Epoch 550, Loss 0.09146970510482788
Epoch 600, Loss 0.07198427617549896
Epoch 650, Loss 0.0666075050830841
Optional mclust¶
If R and rpy2 are available, this step runs mclust once for optional clustering refinement. Otherwise, it is skipped automatically.
try:
    result_adata = steer.mclust_R(result_adata, num_cluster=cfg.expert)
    print("Initial clustering with mclust completed.")
except Exception as e:
    print("mclust step skipped.")
    print(f"Reason: {e}")
torch.cuda.empty_cache()
R[write to console]: __ __
____ ___ _____/ /_ _______/ /_
/ __ `__ \/ ___/ / / / / ___/ __/
/ / / / / / /__/ / /_/ (__ ) /_
/_/ /_/ /_/\___/_/\__,_/____/\__/ version 6.1.1
Type 'citation("mclust")' for citing this R package in publications.
/nvme/users/liuzhy/miniconda3/envs/steer_dev/lib/python3.10/site-packages/rpy2/robjects/numpy2ri.py:241: DeprecationWarning: The global conversion available with activate() is deprecated and will be removed in the next major release. Use a local converter.
warnings.warn('The global conversion available with activate() '
fitting ... |======================================================================| 100%
Initial clustering with mclust completed.
Prior Inference¶
Infer the kinetic prior from the first-stage result for use in the downstream kinetic-learning stage.
prior_manager = PriorInferenceManager(result_adata, df, RESULT_PATH, seed=cfg.seed)
prior_manager.task1_define_fine_clusters(method=cfg.fine_method, target_size=cfg.target_size)
prior_manager.task2_filter_genes(based_on="expert", keep_ngene=cfg.filter_gene_number, use_filter=cfg.use_filter)
prior_manager.task3_calc_convexity(based_on=cfg.direction_base)
result_adata, df_updated, fine_clus_vec_np = prior_manager.finalize_for_training()
print("Prior inference completed.")
--- Task 1: Generating Fine Clusters (Method: none) ---
-> Method is 'none', copying Expert clusters to Fine clusters.
--- Task 2: Filtering Genes (Based on: EXPERT, Keep: 1000) ---
--- Task 3: Calculating Direction (Based on: EXPERT) ---
Starting parallel processing with n_jobs=-1...
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers.
[Parallel(n_jobs=-1)]: Done 32 tasks | elapsed: 11.8s
[Parallel(n_jobs=-1)]: Done 194 tasks | elapsed: 19.2s
[Parallel(n_jobs=-1)]: Done 392 tasks | elapsed: 20.2s
[Parallel(n_jobs=-1)]: Done 626 tasks | elapsed: 20.8s
[Parallel(n_jobs=-1)]: Done 896 tasks | elapsed: 21.4s
[Parallel(n_jobs=-1)]: Done 1202 tasks | elapsed: 21.9s
[Parallel(n_jobs=-1)]: Done 1544 tasks | elapsed: 22.4s
[Parallel(n_jobs=-1)]: Done 2000 out of 2000 | elapsed: 23.1s finished
Aggregating results... Done.
--- Finalizing: Restoring Expert labels & Preparing Fine Cluster Vector ---
Prior inference completed.
Velocity Dataset¶
Subset to velocity genes and rebuild the data objects required for the second training stage.
prior_adata = result_adata.copy()
raw_adata = sc.read_h5ad(INPUT_FILE)
if "X_pca" in raw_adata.obsm:
del raw_adata.obsm["X_pca"]
if "celltype1" in raw_adata.obs and "celltype" not in raw_adata.obs:
raw_adata.obs["celltype"] = raw_adata.obs["celltype1"]
scv.pp.filter_and_normalize(
raw_adata,
min_shared_counts=20,
n_top_genes=2000,
)
assert all(prior_adata.obs_names == raw_adata.obs_names), "Observation names are not aligned!"
assert all(prior_adata.var_names == raw_adata.var_names), "Variable names are not aligned!"
raw_adata.layers["pred_cell_type"] = prior_adata.layers["pred_cell_type"]
raw_adata.obsm["X_pre_embed"] = prior_adata.obsm["X_pre_embed"]
raw_adata.obs["pred_cluster"] = prior_adata.obs["pred_cluster"].astype(int)
velo_adata = raw_adata[:, prior_adata.var["is_velocity_gene"]].copy()
df_fine, adjacency_matrix_fine, velo_adata = steer.preprocess_anndata(
velo_adata,
npc=30,
NUM_AD_NEIGH=30,
SMOOTH_NEIGH=cfg.smooth_neigh,
moments_adj=True,
neighbor_metric=cfg.neighbor_metric,
use_us=cfg.use_us,
)
dataset_fine = steer.preload_datasets_all_genes_anndata(
df=df_fine,
MODEL_MODE="whole",
adata=velo_adata,
)
pyg_data_fine = steer.create_pyg_data(dataset_fine, adjacency_matrix_fine, normalize=True)
pyg_data_fine.fine_clus_vec = torch.tensor(fine_clus_vec_np, dtype=torch.long, device=device)
print(velo_adata)
print("Velocity-gene subset shape:", velo_adata.shape)
Filtered out 47456 genes that are detected 20 counts (shared).
Normalized count data: X, spliced, unspliced.
Extracted 2000 highly variable genes.
Logarithmized X.
/nvme/users/liuzhy/miniconda3/envs/steer_dev/lib/python3.10/site-packages/scvelo/preprocessing/utils.py:705: DeprecationWarning: `log1p` is deprecated since scVelo v0.3.0 and will be removed in a future version. Please use `log1p` from `scanpy.pp` instead. log1p(adata)
computing moments based on connectivities
finished (0:00:01) --> added
'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
AnnData object with n_obs × n_vars = 9815 × 1000
obs: 'sample', 'stage', 'sequencing.batch', 'theiler', 'celltype', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts', 'pred_cluster'
var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'MURK_gene', 'Δm', 'scaled Δm', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'
uns: 'celltype_colors', 'log1p', 'neighbors'
obsm: 'X_umap', 'X_pre_embed', 'X_pca_combined', 'X_pca_moments'
layers: 'spliced', 'unspliced', 'pred_cell_type', 'Ms', 'Mu', 'scale_Mu', 'scale_Ms'
obsp: 'distances', 'connectivities'
Velocity-gene subset shape: (9815, 1000)
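The alignment assertions in this cell stop the run if cell or gene names diverge between the two objects, but they do not report where the mismatch is. A slightly more informative check can be sketched as below. check_aligned is a hypothetical helper written for this example; it works on any two sequences of names, such as list(prior_adata.obs_names) and list(raw_adata.obs_names).

```python
from typing import Sequence


def check_aligned(left: Sequence[str], right: Sequence[str], label: str) -> None:
    """Raise ValueError describing the first mismatch between two name lists."""
    if len(left) != len(right):
        raise ValueError(f"{label}: length mismatch ({len(left)} vs {len(right)})")
    for i, (a, b) in enumerate(zip(left, right)):
        if a != b:
            raise ValueError(f"{label}: first mismatch at position {i}: {a!r} vs {b!r}")


cells_a = ["AAAGATCTCTCGAA", "AATCTCACTGCTTT", "AATGGCTGAAGATG"]
cells_b = ["AAAGATCTCTCGAA", "AATGGCTGAAGATG", "AATCTCACTGCTTT"]

check_aligned(cells_a, cells_a, "obs_names")  # identical order: passes silently
try:
    check_aligned(cells_a, cells_b, "obs_names")
except ValueError as err:
    print(err)
```

Checking the order matters here because layers and embeddings are copied between the two objects by position, not by name.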
Kinetic Learning¶
Run the main kinetic-learning stage using the prior inferred above.
- full mode: preserve the behavior of the original run by not passing extra training-control arguments, so STEER defaults are used
- fast mode: use a lightweight quickstart-style configuration, including expert_mode="slim", shorter pretraining/clustering, velo_batch_size=2048, and more aggressive early stopping
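Early stopping of the kind controlled by MIN_IMPRO and PATIENCE is commonly implemented by tracking the best loss seen so far and stopping once it has not improved by at least a minimum amount for a given number of epochs. The sketch below shows that generic pattern. Whether STEER applies exactly this rule (for example, absolute versus relative improvement) is an assumption, and the class is not part of the STEER API.

```python
class EarlyStopper:
    """Generic patience-based early stopping (illustrative, not STEER's code)."""

    def __init__(self, min_improvement: float, patience: int) -> None:
        self.min_improvement = min_improvement
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.stale_epochs = 0

    def step(self, epoch: int, loss: float) -> bool:
        """Record one epoch; return True when training should stop."""
        if self.best_loss - loss >= self.min_improvement:
            # Meaningful improvement: remember it and reset the counter.
            self.best_loss = loss
            self.best_epoch = epoch
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


stopper = EarlyStopper(min_improvement=0.001, patience=3)
losses = [1.00, 0.90, 0.85, 0.8495, 0.8493, 0.8492, 0.80]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_epoch, stopper.best_loss)
```

A training loop using this pattern would typically also keep a copy of the model state from the best epoch and restore it at the end, which is consistent with the "Loading best model state from memory..." message in the training log.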
stage2_kwargs = dict(
device=device,
device2=device,
pyg_data=pyg_data_fine,
MODEL_MODE="whole",
adata=velo_adata,
NUM_LOSS_NEIGH=30,
max_n_cluster=cfg.expert,
corr_mode=cfg.corr_mode,
path=RESULT_PATH,
)
stage2_kwargs.update(
get_stage2_training_overrides(
cfg.training_mode,
getattr(cfg, "finetune_epochs", None),
)
)
velo_adata = steer.model_training_share_neighbor_adata(**stage2_kwargs)
print("Kinetic-learning stage completed.")
Using Fine Cluster Vector for Correlation Loss. Epoch 0, Loss 0.14109943807125092 Epoch 50, Loss 0.08040767163038254 Epoch 100, Loss 0.0727798119187355 Epoch 150, Loss 0.0665828213095665 Epoch 200, Loss 0.06125184893608093 Epoch 250, Loss 0.05659956857562065 Epoch 300, Loss 0.052438072860240936 Epoch 350, Loss 0.04880218207836151 Epoch 400, Loss 0.045311491936445236 Epoch 450, Loss 0.042271632701158524 Epoch 500, Loss 1.08180832862854 Epoch 550, Loss 0.10931762307882309 Epoch 600, Loss 0.08053947985172272 Epoch 650, Loss 0.07452398538589478 Epoch 700, Loss 0.0709378719329834 Epoch 750, Loss 1.0369415283203125 GATE: 0.03271901234984398, Cluster: 0.03608018159866333, Time_cor: 0.9533256888389587, Time_smooth: 0.014816690236330032 Epoch 800, Loss 1.0325865745544434 GATE: 0.030481768772006035, Cluster: 0.03579676151275635, Time_cor: 0.9523569941520691, Time_smooth: 0.013951120898127556 Epoch 850, Loss 1.0295418500900269 GATE: 0.02857108786702156, Cluster: 0.03580904006958008, Time_cor: 0.9519582986831665, Time_smooth: 0.013203334994614124 Epoch 900, Loss 1.0272172689437866 GATE: 0.02684497833251953, Cluster: 0.03601807355880737, Time_cor: 0.9518389105796814, Time_smooth: 0.012515307404100895 Epoch 950, Loss 1.292514681816101 GATE: 0.025294823572039604, Cluster: 0.03572434186935425, Velo: 0.2757757008075714, Time_cor: 0.9437475204467773, Time_smooth: 0.011972339823842049 Epoch 1000, Loss 1.2794350385665894 GATE: 0.023939985781908035, Cluster: 0.03722625970840454, Velo: 0.2713790237903595, Time_cor: 0.9354447722434998, Time_smooth: 0.011444964446127415 Epoch 1050, Loss 1.2664815187454224 GATE: 0.022752147167921066, Cluster: 0.035604894161224365, Velo: 0.27082088589668274, Time_cor: 0.9261539578437805, Time_smooth: 0.01114966906607151 Epoch 1100, Loss 1.2526360750198364 GATE: 0.02142851985991001, Cluster: 0.03566098213195801, Velo: 0.26809173822402954, Time_cor: 0.9169061779975891, Time_smooth: 0.01054866798222065 Epoch 1150, Loss 1.1558679342269897 GATE: 
0.020799832418560982, Cluster: 0.035542428493499756, Velo: 0.1774023473262787, Time_cor: 0.9118537306785583, Time_smooth: 0.010269665159285069 Epoch 1200, Loss 1.1309895515441895 GATE: 0.020031284540891647, Cluster: 0.03561359643936157, Velo: 0.1609746664762497, Time_cor: 0.9043165445327759, Time_smooth: 0.010053353384137154 Epoch 1250, Loss 1.125893473625183 GATE: 0.019720077514648438, Cluster: 0.03553992509841919, Velo: 0.15845948457717896, Time_cor: 0.9022969603538513, Time_smooth: 0.009877052158117294 Epoch 1300, Loss 1.1091359853744507 GATE: 0.018981236964464188, Cluster: 0.035930335521698, Velo: 0.1510349065065384, Time_cor: 0.8936960101127625, Time_smooth: 0.009493498131632805 Epoch 1350, Loss 1.0973705053329468 GATE: 0.018225640058517456, Cluster: 0.035554468631744385, Velo: 0.14749284088611603, Time_cor: 0.8870719075202942, Time_smooth: 0.009025665931403637 Epoch 1400, Loss 1.0960112810134888 GATE: 0.017862854525446892, Cluster: 0.035519957542419434, Velo: 0.14989641308784485, Time_cor: 0.8839266896247864, Time_smooth: 0.008805341087281704 Epoch 1450, Loss 1.0910667181015015 GATE: 0.017772823572158813, Cluster: 0.03551077842712402, Velo: 0.1457340568304062, Time_cor: 0.8833015561103821, Time_smooth: 0.008747444488108158 Epoch 1500, Loss 1.0809119939804077 GATE: 0.017364220693707466, Cluster: 0.03594863414764404, Velo: 0.1441287398338318, Time_cor: 0.8747639060020447, Time_smooth: 0.008706476539373398 Epoch 1550, Loss 1.069075584411621 GATE: 0.0163844246417284, Cluster: 0.03556656837463379, Velo: 0.14279714226722717, Time_cor: 0.866215169429779, Time_smooth: 0.008112302981317043 Epoch 1600, Loss 1.0572261810302734 GATE: 0.015795046463608742, Cluster: 0.03553164005279541, Velo: 0.13903343677520752, Time_cor: 0.8590683937072754, Time_smooth: 0.0077977669425308704 Epoch 1650, Loss 1.0543986558914185 GATE: 0.01528978906571865, Cluster: 0.03548640012741089, Velo: 0.14247997105121613, Time_cor: 0.8536363840103149, Time_smooth: 0.007506120949983597 Epoch 1700, 
Loss 1.045913815498352 GATE: 0.01500642392784357, Cluster: 0.035484910011291504, Velo: 0.13817471265792847, Time_cor: 0.8499493598937988, Time_smooth: 0.00729833310469985 Epoch 1750, Loss 1.0395641326904297 GATE: 0.014822980388998985, Cluster: 0.03547435998916626, Velo: 0.13436169922351837, Time_cor: 0.8478351831436157, Time_smooth: 0.0070698875933885574 Epoch 1800, Loss 1.0406126976013184 GATE: 0.014715846627950668, Cluster: 0.035462260246276855, Velo: 0.1365250200033188, Time_cor: 0.8469399809837341, Time_smooth: 0.00696960836648941 Epoch 1850, Loss 1.038997769355774 GATE: 0.014695823192596436, Cluster: 0.03545719385147095, Velo: 0.13516856729984283, Time_cor: 0.8467268943786621, Time_smooth: 0.006949255242943764 Epoch 1900, Loss 1.0346438884735107 GATE: 0.01439008116722107, Cluster: 0.035888612270355225, Velo: 0.13869230449199677, Time_cor: 0.8385389447212219, Time_smooth: 0.0071339854039251804 Epoch 1950, Loss 1.0172412395477295 GATE: 0.013712665997445583, Cluster: 0.03553551435470581, Velo: 0.13105349242687225, Time_cor: 0.8302122354507446, Time_smooth: 0.006727301049977541 Epoch 2000, Loss 1.0127034187316895 GATE: 0.01333946269005537, Cluster: 0.03552645444869995, Velo: 0.13472315669059753, Time_cor: 0.8224897384643555, Time_smooth: 0.006624638568609953 Epoch 2050, Loss 1.0049306154251099 GATE: 0.012749305926263332, Cluster: 0.035533785820007324, Velo: 0.13550423085689545, Time_cor: 0.8151803016662598, Time_smooth: 0.005962909199297428 Epoch 2100, Loss 0.9960180521011353 GATE: 0.012277310714125633, Cluster: 0.03553658723831177, Velo: 0.13402679562568665, Time_cor: 0.8086315393447876, Time_smooth: 0.005545863416045904 Epoch 2150, Loss 0.9865667819976807 GATE: 0.011948658153414726, Cluster: 0.03546738624572754, Velo: 0.13091211020946503, Time_cor: 0.8028170466423035, Time_smooth: 0.005421570967882872 Epoch 2200, Loss 0.9817181825637817 GATE: 0.011586485430598259, Cluster: 0.03544694185256958, Velo: 0.13166312873363495, Time_cor: 0.7978119850158691, Time_smooth: 
0.005209631752222776 Epoch 2250, Loss 0.979472815990448 GATE: 0.01132783479988575, Cluster: 0.035415709018707275, Velo: 0.13415344059467316, Time_cor: 0.7935397624969482, Time_smooth: 0.00503606628626585 Epoch 2300, Loss 0.9730192422866821 GATE: 0.011139970272779465, Cluster: 0.035383760929107666, Velo: 0.1315372884273529, Time_cor: 0.7901217937469482, Time_smooth: 0.00483643589541316 Epoch 2350, Loss 0.9661833643913269 GATE: 0.01091655995696783, Cluster: 0.03535866737365723, Velo: 0.1277928501367569, Time_cor: 0.7873796820640564, Time_smooth: 0.004735573194921017 Epoch 2400, Loss 0.9657718539237976 GATE: 0.010839035734534264, Cluster: 0.03536158800125122, Velo: 0.1295243203639984, Time_cor: 0.7853979468345642, Time_smooth: 0.004648955073207617 Epoch 2450, Loss 0.963250458240509 GATE: 0.01068950816988945, Cluster: 0.03533369302749634, Velo: 0.12872150540351868, Time_cor: 0.7839569449424744, Time_smooth: 0.0045487661845982075 Epoch 2500, Loss 0.958224356174469 GATE: 0.010621265508234501, Cluster: 0.03532290458679199, Velo: 0.12471798062324524, Time_cor: 0.7830479741096497, Time_smooth: 0.004514242056757212 Epoch 2550, Loss 0.9571680426597595 GATE: 0.010581396520137787, Cluster: 0.03532141447067261, Velo: 0.12425254285335541, Time_cor: 0.782532811164856, Time_smooth: 0.0044799065217375755 Epoch 2600, Loss 0.9588713049888611 GATE: 0.010565049014985561, Cluster: 0.03531956672668457, Velo: 0.12622462213039398, Time_cor: 0.7822934985160828, Time_smooth: 0.004468538332730532 Epoch 2650, Loss 0.959942102432251 GATE: 0.01055548619478941, Cluster: 0.03531849384307861, Velo: 0.12741543352603912, Time_cor: 0.7821935415267944, Time_smooth: 0.004459156654775143 Epoch 2700, Loss 0.9595556855201721 GATE: 0.01151309534907341, Cluster: 0.03546798229217529, Velo: 0.1293787658214569, Time_cor: 0.7764725685119629, Time_smooth: 0.006723302416503429 Epoch 2750, Loss 0.9492339491844177 GATE: 0.010415156371891499, Cluster: 0.035368919372558594, Velo: 0.1284705400466919, Time_cor: 
0.7691971063613892, Time_smooth: 0.005782263353466988 Epoch 2800, Loss 0.9394678473472595 GATE: 0.009940660558640957, Cluster: 0.035356104373931885, Velo: 0.12632668018341064, Time_cor: 0.7625665068626404, Time_smooth: 0.00527793588116765 Epoch 2850, Loss 0.9458891749382019 GATE: 0.010901173576712608, Cluster: 0.03637421131134033, Velo: 0.13060104846954346, Time_cor: 0.756642758846283, Time_smooth: 0.011370023712515831 Epoch 2900, Loss 0.9357679486274719 GATE: 0.009776478633284569, Cluster: 0.0353388786315918, Velo: 0.1305444985628128, Time_cor: 0.7503698468208313, Time_smooth: 0.009738280437886715 Epoch 2950, Loss 0.9246451258659363 GATE: 0.009363371878862381, Cluster: 0.03557872772216797, Velo: 0.12620827555656433, Time_cor: 0.7447048425674438, Time_smooth: 0.008789892308413982 Epoch 3000, Loss 0.916650652885437 GATE: 0.008823797106742859, Cluster: 0.03537815809249878, Velo: 0.12518121302127838, Time_cor: 0.739386260509491, Time_smooth: 0.007881241850554943 Epoch 3050, Loss 0.9141762852668762 GATE: 0.008566491305828094, Cluster: 0.03528773784637451, Velo: 0.12849532067775726, Time_cor: 0.7344586253166199, Time_smooth: 0.007368121761828661 Epoch 3100, Loss 0.9077210426330566 GATE: 0.008319217711687088, Cluster: 0.03546780347824097, Velo: 0.12712673842906952, Time_cor: 0.7299159169197083, Time_smooth: 0.006891362834721804 Epoch 3150, Loss 0.9023134708404541 GATE: 0.008110325783491135, Cluster: 0.03544551134109497, Velo: 0.12669342756271362, Time_cor: 0.7255601286888123, Time_smooth: 0.006504059303551912 Epoch 3200, Loss 0.8971222043037415 GATE: 0.00795813649892807, Cluster: 0.0353546142578125, Velo: 0.1260329782962799, Time_cor: 0.7216439247131348, Time_smooth: 0.006132567301392555 Epoch 3250, Loss 0.8931406736373901 GATE: 0.007819355465471745, Cluster: 0.03536856174468994, Velo: 0.12600265443325043, Time_cor: 0.7180902361869812, Time_smooth: 0.0058598811738193035 Epoch 3300, Loss 0.8890437483787537 GATE: 0.007733501028269529, Cluster: 0.03526508808135986, Velo: 
0.1256316602230072, Time_cor: 0.7148054838180542, Time_smooth: 0.005608047358691692 Epoch 3350, Loss 0.8830109238624573 GATE: 0.007533684838563204, Cluster: 0.03526872396469116, Velo: 0.12300821393728256, Time_cor: 0.7118157744407654, Time_smooth: 0.005384492687880993 Epoch 3400, Loss 0.8790327310562134 GATE: 0.00740026356652379, Cluster: 0.035327136516571045, Velo: 0.12192286550998688, Time_cor: 0.70916348695755, Time_smooth: 0.005219000857323408 Epoch 3450, Loss 0.8792266249656677 GATE: 0.007470160722732544, Cluster: 0.03537529706954956, Velo: 0.1244133710861206, Time_cor: 0.7068673968315125, Time_smooth: 0.0051004099659621716 Epoch 3500, Loss 0.872403621673584 GATE: 0.007199693936854601, Cluster: 0.03524094820022583, Velo: 0.12034698575735092, Time_cor: 0.7046304941177368, Time_smooth: 0.004985523875802755 Epoch 3550, Loss 0.8719374537467957 GATE: 0.007101850118488073, Cluster: 0.03524857759475708, Velo: 0.12197044491767883, Time_cor: 0.7027565240859985, Time_smooth: 0.0048600174486637115 Epoch 3600, Loss 0.8711535334587097 GATE: 0.0070281317457556725, Cluster: 0.03523343801498413, Velo: 0.12301947921514511, Time_cor: 0.7011203169822693, Time_smooth: 0.004752186127007008 Epoch 3650, Loss 0.8672225475311279 GATE: 0.006960507016628981, Cluster: 0.03521615266799927, Velo: 0.12068920582532883, Time_cor: 0.699708878993988, Time_smooth: 0.004647792782634497 Epoch 3700, Loss 0.870712399482727 GATE: 0.006906755734235048, Cluster: 0.035212039947509766, Velo: 0.12555374205112457, Time_cor: 0.6984842419624329, Time_smooth: 0.004555596504360437 Epoch 3750, Loss 0.8653516173362732 GATE: 0.006836210377514362, Cluster: 0.03522378206253052, Velo: 0.1213570237159729, Time_cor: 0.6974555850028992, Time_smooth: 0.004479007329791784 Epoch 3800, Loss 0.8661607503890991 GATE: 0.006812657695263624, Cluster: 0.035222411155700684, Velo: 0.12309615314006805, Time_cor: 0.6966065764427185, Time_smooth: 0.004422987345606089 Epoch 3850, Loss 0.8645286560058594 GATE: 0.006793119478970766, 
Cluster: 0.03523498773574829, Velo: 0.12223473191261292, Time_cor: 0.6959128975868225, Time_smooth: 0.00435290951281786
Epoch 3900, Loss 0.8602012395858765 GATE: 0.006737581919878721, Cluster: 0.035210609436035156, Velo: 0.11860556155443192, Time_cor: 0.6953423023223877, Time_smooth: 0.0043051778338849545
Epoch 3950, Loss 0.8626157641410828 GATE: 0.006707982625812292, Cluster: 0.03521150350570679, Velo: 0.12154485285282135, Time_cor: 0.6949059367179871, Time_smooth: 0.004245446994900703
Epoch 4000, Loss 0.8617814779281616 GATE: 0.006687719840556383, Cluster: 0.035206615924835205, Velo: 0.12108967453241348, Time_cor: 0.6945785880088806, Time_smooth: 0.00421885447576642
Epoch 4050, Loss 0.8628175258636475 GATE: 0.006674243602901697, Cluster: 0.035209059715270996, Velo: 0.12239226698875427, Time_cor: 0.6943460702896118, Time_smooth: 0.004195846151560545
Epoch 4100, Loss 0.8652332425117493 GATE: 0.006665047723799944, Cluster: 0.03520578145980835, Velo: 0.12499677389860153, Time_cor: 0.6941866874694824, Time_smooth: 0.004178931936621666
Epoch 4150, Loss 0.8636197447776794 GATE: 0.006680463440716267, Cluster: 0.03520625829696655, Velo: 0.12345963716506958, Time_cor: 0.6940969824790955, Time_smooth: 0.004176464397460222
Epoch 4200, Loss 0.8618668913841248 GATE: 0.006657164078205824, Cluster: 0.03520464897155762, Velo: 0.12181941419839859, Time_cor: 0.6940221786499023, Time_smooth: 0.004163485486060381
Epoch 4250, Loss 0.8621706962585449 GATE: 0.006653263233602047, Cluster: 0.03520435094833374, Velo: 0.12217675894498825, Time_cor: 0.6939778923988342, Time_smooth: 0.004158438183367252
Epoch 4300, Loss 0.8610144853591919 GATE: 0.0075624561868608, Cluster: 0.03553217649459839, Velo: 0.12089953571557999, Time_cor: 0.6910329461097717, Time_smooth: 0.005987387616187334
Epoch 4350, Loss 0.8562915921211243 GATE: 0.006939211394637823, Cluster: 0.03549373149871826, Velo: 0.12121127545833588, Time_cor: 0.6870712041854858, Time_smooth: 0.00557619147002697
Epoch 4400, Loss 0.8538849949836731 GATE: 0.006534209940582514, Cluster: 0.03528624773025513, Velo: 0.12359705567359924, Time_cor: 0.6834689378738403, Time_smooth: 0.00499848322942853
Epoch 4450, Loss 0.8506528735160828 GATE: 0.0063857208006083965, Cluster: 0.035353899002075195, Velo: 0.1240384429693222, Time_cor: 0.6802471280097961, Time_smooth: 0.004627622198313475
Epoch 4500, Loss 0.8449651598930359 GATE: 0.0062003242783248425, Cluster: 0.03527170419692993, Velo: 0.121892549097538, Time_cor: 0.6772424578666687, Time_smooth: 0.004358100704848766
Epoch 4550, Loss 0.8414856195449829 GATE: 0.006147367879748344, Cluster: 0.035823822021484375, Velo: 0.12085500359535217, Time_cor: 0.6745058298110962, Time_smooth: 0.0041535962373018265
Epoch 4600, Loss 0.8389448523521423 GATE: 0.006008945871144533, Cluster: 0.03525960445404053, Velo: 0.12160277366638184, Time_cor: 0.6720573306083679, Time_smooth: 0.004016215912997723
Epoch 4650, Loss 0.835629403591156 GATE: 0.005907153245061636, Cluster: 0.03519797325134277, Velo: 0.12110355496406555, Time_cor: 0.669592559337616, Time_smooth: 0.003828191664069891
Epoch 4700, Loss 0.8353386521339417 GATE: 0.0058616516180336475, Cluster: 0.03531503677368164, Velo: 0.12283740937709808, Time_cor: 0.667580246925354, Time_smooth: 0.003744277870282531
Epoch 4750, Loss 0.8313387036323547 GATE: 0.005791990552097559, Cluster: 0.0352858304977417, Velo: 0.12095768004655838, Time_cor: 0.6656663417816162, Time_smooth: 0.0036369203589856625
Epoch 4800, Loss 0.8307786583900452 GATE: 0.0056674666702747345, Cluster: 0.035282790660858154, Velo: 0.12264423072338104, Time_cor: 0.6637043356895447, Time_smooth: 0.0034798108972609043
Epoch 4850, Loss 0.8292815685272217 GATE: 0.005608834326267242, Cluster: 0.03532284498214722, Velo: 0.12290289998054504, Time_cor: 0.6620429158210754, Time_smooth: 0.003404090413823724
Epoch 4900, Loss 0.8270974159240723 GATE: 0.005509241484105587, Cluster: 0.035174012184143066, Velo: 0.1225314736366272, Time_cor: 0.660561740398407, Time_smooth: 0.0033209254033863544
Epoch 4950, Loss 0.8260904550552368 GATE: 0.005498659797012806, Cluster: 0.035250067710876465, Velo: 0.1228942945599556, Time_cor: 0.6591810584068298, Time_smooth: 0.0032663890160620213
Loading best model state from memory...
Kinetic-learning stage completed.
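The training log above follows a regular pattern (epoch, total loss, then the GATE, Cluster, Velo, Time_cor and Time_smooth terms), so it can be turned into structured records for plotting or convergence checks. The following is a minimal sketch, not part of STEER: `parse_training_log` and `log_text` are hypothetical names, and it assumes you have captured the printed log as a string yourself.

```python
import re

# Component names in the order they appear in each logged checkpoint.
FIELDS = ("Loss", "GATE", "Cluster", "Velo", "Time_cor", "Time_smooth")

# A float such as 0.8270974159240723 or 5e-05.
_NUM = r"([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)"

# "Epoch N, Loss x GATE: a, Cluster: b, ..." with flexible whitespace, so the
# pattern also matches when a record is wrapped across lines.
LOG_PATTERN = re.compile(
    r"Epoch\s+(\d+),\s+Loss\s+" + _NUM
    + "".join(rf"\s+{name}:\s+{_NUM},?" for name in FIELDS[1:])
)


def parse_training_log(text):
    """Return one dict per logged epoch, in order of appearance."""
    records = []
    for match in LOG_PATTERN.finditer(text):
        epoch, *values = match.groups()
        record = {"Epoch": int(epoch)}
        record.update({name: float(v) for name, v in zip(FIELDS, values)})
        records.append(record)
    return records


# Two checkpoints copied from the log above; in practice `log_text` would be
# the captured output of the training call.
log_text = (
    "Epoch 4900, Loss 0.8270974159240723 GATE: 0.005509241484105587, "
    "Cluster: 0.035174012184143066, Velo: 0.1225314736366272, "
    "Time_cor: 0.660561740398407, Time_smooth: 0.0033209254033863544 "
    "Epoch 4950, Loss 0.8260904550552368 GATE: 0.005498659797012806, "
    "Cluster: 0.035250067710876465, Velo: 0.1228942945599556, "
    "Time_cor: 0.6591810584068298, Time_smooth: 0.0032663890160620213"
)
records = parse_training_log(log_text)
best = min(records, key=lambda r: r["Loss"])
print(f"parsed {len(records)} checkpoints; lowest loss at epoch {best['Epoch']}")
```

The resulting list of dicts can be passed directly to `pd.DataFrame(records)` if you want to plot each loss term against the epoch.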
Velocity Visualization
Build the velocity graph from the pred_vs_norm layer, then inspect the learned flow field on the STEER latent UMAP and on the expression UMAP.
# Work on a copy so velo_adata itself is left unchanged.
result_adata = velo_adata.copy()
# For each learned representation, build a dedicated neighbor graph and UMAP.
# copy=True returns a new AnnData, so the existing "X_umap" in result_adata is not overwritten.
sc.pp.neighbors(result_adata, n_neighbors=30, use_rep="X_para_t", key_added="para_t_neighbors")
temp_adata = sc.tl.umap(result_adata, neighbors_key="para_t_neighbors", copy=True)
result_adata.obsm["X_umap_para_t_embed"] = temp_adata.obsm["X_umap"]
sc.pp.neighbors(result_adata, n_neighbors=30, use_rep="X_para", key_added="para_neighbors")
temp_adata = sc.tl.umap(result_adata, neighbors_key="para_neighbors", copy=True)
result_adata.obsm["X_umap_para_embed"] = temp_adata.obsm["X_umap"]
sc.pp.neighbors(result_adata, n_neighbors=30, use_rep="X_refine_embed_t", key_added="refine_embed_t_neighbors")
temp_adata = sc.tl.umap(result_adata, neighbors_key="refine_embed_t_neighbors", copy=True)
result_adata.obsm["X_umap_refine_embed_t"] = temp_adata.obsm["X_umap"]
# Recompute the default neighbor graph (stored under "neighbors") on X_refine_embed before building the velocity graph.
sc.pp.neighbors(result_adata, use_rep="X_refine_embed", n_neighbors=30)
# Build the velocity graph from the pred_vs_norm velocities and the model_Ms expression layer.
steer.velocity_graph(result_adata, vkey="pred_vs_norm", xkey="model_Ms")
print("obs columns:")
print(sorted(result_adata.obs.columns.tolist()))
print("\nobsm keys:")
print(sorted(result_adata.obsm.keys()))
print("\nlayers:")
print(sorted(result_adata.layers.keys()))
computing velocity graph (using 1/128 cores)
finished (0:00:10) --> added
'pred_vs_norm_graph', sparse matrix with cosine correlations (adata.uns)
obs columns:
['Expert', 'Expert Weight', 'Pred Time', 'celltype', 'initial_size', 'initial_size_spliced', 'initial_size_unspliced', 'n_counts', 'pred_vs_norm_self_transition', 'pretrain_cluster', 'sample', 'sequencing.batch', 'stage', 'theiler']
obsm keys:
['X_alpha', 'X_beta', 'X_gamma', 'X_para', 'X_para_t', 'X_pca_combined', 'X_pca_moments', 'X_pre_embed', 'X_refine_embed', 'X_refine_embed_t', 'X_umap', 'X_umap_para_embed', 'X_umap_para_t_embed', 'X_umap_refine_embed_t', 'cluster_matrix']
layers:
['final_recon_s', 'final_recon_u', 'init_regulate_state', 'model_Ms', 'model_Mu', 'orig_s', 'orig_u', 'pred_time_layer', 'pred_vs', 'pred_vs_norm', 'pred_vu', 'pred_vu_norm', 'recon_alpha', 'recon_alpha_norm', 'recon_beta', 'recon_gamma', 'recon_gamma_norm', 'regulate_state', 'scale_Ms', 'scale_Mu']
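The log above states that the velocity graph stores cosine correlations. As a rough illustration of that idea (in the style of scVelo's velocity graph, and not a description of STEER's internals), each cell's velocity vector can be compared with the displacement toward each neighbor. The function names and toy values below are hypothetical.

```python
import numpy as np


def cosine_correlation(velocity, displacement):
    """Cosine between two mean-centred vectors (0.0 if either has zero norm)."""
    v = velocity - velocity.mean()
    d = displacement - displacement.mean()
    denom = np.linalg.norm(v) * np.linalg.norm(d)
    return float(v @ d / denom) if denom > 0 else 0.0


def neighbor_scores(X, V, cell, neighbors):
    """Score how well the velocity of `cell` points toward each neighbor."""
    return np.array(
        [cosine_correlation(V[cell], X[j] - X[cell]) for j in neighbors]
    )


# Toy data: 4 cells x 3 genes (illustrative values, not from the dataset).
X = np.array([
    [0.0, 0.0, 0.0],
    [2.0, 0.0, -2.0],   # lies along the velocity of cell 0
    [-1.0, 0.0, 1.0],   # lies opposite to it
    [1.0, -2.0, 1.0],   # orthogonal to it
])
V = np.zeros_like(X)
V[0] = [1.0, 0.0, -1.0]

scores = neighbor_scores(X, V, cell=0, neighbors=[1, 2, 3])
print(scores)
```

A score near 1 means the neighbor lies in the direction the cell is predicted to move, a score near -1 means it lies in the opposite direction, and a score near 0 means the two are unrelated. In the real graph these scores are stored sparsely, only for each cell's neighbors.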
result_adata
AnnData object with n_obs × n_vars = 9815 × 1000
obs: 'sample', 'stage', 'sequencing.batch', 'theiler', 'celltype', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts', 'pretrain_cluster', 'Expert', 'Expert Weight', 'Pred Time', 'pred_vs_norm_self_transition'
var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'MURK_gene', 'Δm', 'scaled Δm', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'
uns: 'celltype_colors', 'log1p', 'neighbors', 'para_t_neighbors', 'para_neighbors', 'refine_embed_t_neighbors', 'pred_vs_norm_graph', 'pred_vs_norm_graph_neg', 'pred_vs_norm_params'
obsm: 'X_umap', 'X_pre_embed', 'X_pca_combined', 'X_pca_moments', 'X_refine_embed', 'cluster_matrix', 'X_alpha', 'X_beta', 'X_gamma', 'X_para', 'X_para_t', 'X_refine_embed_t', 'X_umap_para_t_embed', 'X_umap_para_embed', 'X_umap_refine_embed_t'
layers: 'scale_Mu', 'scale_Ms', 'recon_alpha', 'recon_beta', 'recon_gamma', 'pred_vu', 'pred_vs', 'pred_vu_norm', 'pred_vs_norm', 'init_regulate_state', 'regulate_state', 'final_recon_s', 'final_recon_u', 'model_Ms', 'model_Mu', 'orig_s', 'orig_u', 'pred_time_layer', 'recon_alpha_norm', 'recon_gamma_norm'
obsp: 'distances', 'connectivities', 'para_t_neighbors_distances', 'para_t_neighbors_connectivities', 'para_neighbors_distances', 'para_neighbors_connectivities', 'refine_embed_t_neighbors_distances', 'refine_embed_t_neighbors_connectivities'
scv.settings.figdir = RESULT_PATH
scv.set_figure_params(style="scvelo", dpi=150, figsize=(5, 4), transparent=True)
color_key = "celltype" if "celltype" in result_adata.obs else "pred_cluster"
scv.pl.velocity_embedding_stream(
result_adata,
basis="X_umap_refine_embed_t",
vkey="pred_vs_norm",
color=[color_key, "Pred Time", "Expert"],
legend_loc="right",
title="Velocity on STEER latent UMAP",
show=True,
save="velo_latent_umap.png",
)
scv.pl.velocity_embedding_stream(
result_adata,
basis="X_umap",
vkey="pred_vs_norm",
color=[color_key, "Pred Time", "Expert"],
alpha=1,
legend_loc="right",
title="Velocity on expression UMAP",
show=True,
save="velo_expression_umap.png",
)
result_adata.write(os.path.join(RESULT_PATH, "final_adata_em.h5ad"))
computing velocity embedding
finished (0:00:01) --> added
'pred_vs_norm_umap_refine_embed_t', embedded velocity vectors (adata.obsm)
saving figure to file ./results_fast/erythroid_lineage_quickstart/scvelo_velo_latent_umap.png
computing velocity embedding
finished (0:00:01) --> added
'pred_vs_norm_umap', embedded velocity vectors (adata.obsm)
saving figure to file ./results_fast/erythroid_lineage_quickstart/scvelo_velo_expression_umap.png
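Judging from the log above, scVelo writes each figure into scv.settings.figdir and prepends scvelo_ to the name passed via save. A small sketch for checking that the expected files were written; `expected_figure_path` is a hypothetical helper, and the RESULT_PATH value is simply the one shown in the log, so adjust it to your own run.

```python
import os


def expected_figure_path(figdir, save_name, prefix="scvelo_"):
    """Path where a figure saved with save=save_name is expected to appear."""
    return os.path.join(figdir, prefix + save_name)


# Value taken from the log above; replace with your own RESULT_PATH.
RESULT_PATH = "./results_fast/erythroid_lineage_quickstart"

for name in ("velo_latent_umap.png", "velo_expression_umap.png"):
    path = expected_figure_path(RESULT_PATH, name)
    status = "found" if os.path.isfile(path) else "missing"
    print(f"{status}: {path}")
```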